Duibonduil committed on
Commit ae64487 · verified · 1 Parent(s): 3193e62

Upload 10 files
aworld/models/README.md ADDED
@@ -0,0 +1,401 @@
1
+ # AWorld LLM Interface
2
+
3
+ A unified interface for interacting with various LLM providers through a consistent API.
4
+
5
+ ## Features
6
+
7
+ - Unified API for multiple LLM providers (OpenAI-compatible, Anthropic, and Azure OpenAI are currently supported)
8
+ - Synchronous and asynchronous calls with optional initialization control
9
+ - Streaming responses support
10
+ - Tool calls support
11
+ - Unified ModelResponse object for all provider responses
12
+ - Easy extension with custom providers
13
+
14
+ ## Supported Providers
15
+
16
+ - `openai`: Models supporting the OpenAI API protocol (OpenAI and compatible models)
17
+ - `anthropic`: Models supporting the Anthropic API protocol (Claude models)
18
+ - `azure_openai`: Azure OpenAI service
19
+
20
+ ## Basic Usage
21
+
22
+ ### Quick Start
23
+
24
+ ```python
25
+ from aworld.config.conf import AgentConfig
26
+ from aworld.models.llm import get_llm_model, call_llm_model, acall_llm_model
27
+
28
+ # Create configuration
29
+ config = AgentConfig(
30
+ llm_provider="openai", # Options: "openai", "anthropic", "azure_openai"
31
+ llm_model_name="gpt-4o",
32
+ llm_temperature=0.0,
33
+ llm_api_key="your_api_key",
34
+ llm_base_url="your_llm_server_address"
35
+ )
36
+
37
+ # Initialize the model
38
+ model = get_llm_model(config)
39
+
40
+ # Prepare messages
41
+ messages = [
42
+ {"role": "system", "content": "You are a helpful AI assistant."},
43
+ {"role": "user", "content": "Explain Python in three sentences."}
44
+ ]
45
+
46
+ # Get response
47
+ response = model.completion(messages)
48
+ print(response.content) # Access content directly from ModelResponse
49
+ ```
50
+
51
+ ### Using call_llm_model (Recommended)
52
+
53
+ ```python
54
+ from aworld.models.llm import get_llm_model, call_llm_model
55
+
56
+ # Initialize model
57
+ model = get_llm_model(
58
+ llm_provider="openai",
59
+ model_name="gpt-4o",
60
+ api_key="your_api_key",
61
+ base_url="https://api.openai.com/v1"
62
+ )
63
+
64
+ # Prepare messages
65
+ messages = [
66
+ {"role": "system", "content": "You are a helpful AI assistant."},
67
+ {"role": "user", "content": "Write a short poem about programming."}
68
+ ]
69
+
70
+ # Using call_llm_model - returns ModelResponse object
71
+ response = call_llm_model(model, messages)
72
+ print(response.content) # Access content directly from ModelResponse
73
+
74
+ # Stream response with call_llm_model
75
+ for chunk in call_llm_model(model, messages, temperature=0.7, stream=True):
76
+ if chunk.content:
77
+ print(chunk.content, end="", flush=True)
78
+ ```
79
+
80
+ ### Asynchronous Calls with acall_llm_model
81
+
82
+ ```python
83
+ import asyncio
84
+ from aworld.models.llm import get_llm_model, acall_llm_model
85
+
86
+ async def main():
87
+ # Initialize model
88
+ model = get_llm_model(
89
+ llm_provider="anthropic",
90
+ model_name="claude-3-5-sonnet-20241022",
91
+ api_key="your_anthropic_api_key"
92
+ )
93
+
94
+ # Prepare messages
95
+ messages = [
96
+ {"role": "user", "content": "List 3 effective ways to learn programming."}
97
+ ]
98
+
99
+ # Async call with acall_llm_model
100
+ response = await acall_llm_model(model, messages)
101
+ print(response.content)
102
+
103
+ # Async streaming with acall_llm_model
104
+ print("\nStreaming response:")
105
+ async for chunk in await acall_llm_model(model, messages, stream=True):
106
+ if chunk.content:
107
+ print(chunk.content, end="", flush=True)
108
+
109
+ # Run async function
110
+ asyncio.run(main())
111
+ ```
112
+
113
+ ### Selective Sync/Async Initialization
114
+
115
+ For performance optimization, you can control whether the synchronous and asynchronous providers are initialized.
116
+ By default, both `sync_enabled` and `async_enabled` are `True`, so both providers are initialized.
117
+
118
+ ```python
119
+ # Initialize only synchronous provider
120
+ model = get_llm_model(
121
+ llm_provider="openai",
122
+ model_name="gpt-4o",
123
+ sync_enabled=True, # Initialize sync provider
124
+ async_enabled=False, # Don't initialize async provider
125
+ api_key="your_api_key"
126
+ )
127
+
128
+ # Initialize only asynchronous provider
129
+ model = get_llm_model(
130
+ llm_provider="anthropic",
131
+ model_name="claude-3-5-sonnet-20241022",
132
+ sync_enabled=False, # Don't initialize sync provider
133
+ async_enabled=True, # Initialize async provider
134
+ api_key="your_api_key"
135
+ )
136
+
137
+ # Initialize both (default behavior)
138
+ model = get_llm_model(
139
+ llm_provider="openai",
140
+ model_name="gpt-4o",
141
+ sync_enabled=True,
142
+ async_enabled=True
143
+ )
144
+ ```
145
+
146
+ ### HTTP Client Mode
147
+
148
+ You can use direct HTTP requests instead of the SDK by specifying the `client_type=ClientType.HTTP` parameter:
149
+
150
+ ```python
151
+ from aworld.config.conf import AgentConfig, ClientType
152
+ from aworld.models.llm import get_llm_model, call_llm_model
153
+
154
+ # Initialize model with HTTP client mode
155
+ model = get_llm_model(
156
+ llm_provider="openai",
157
+ model_name="gpt-4o",
158
+ api_key="your_api_key",
159
+ base_url="https://api.openai.com/v1",
160
+ client_type=ClientType.HTTP # Use HTTP client instead of SDK
161
+ )
162
+
163
+ # Use it exactly the same way as SDK mode
164
+ messages = [
165
+ {"role": "system", "content": "You are a helpful AI assistant."},
166
+ {"role": "user", "content": "Tell me a short joke."}
167
+ ]
168
+
169
+ # The model uses HTTP requests under the hood
170
+ response = call_llm_model(model, messages)
171
+ print(response.content)
172
+
173
+ # Streaming also works with HTTP client
174
+ for chunk in call_llm_model(model, messages, stream=True):
175
+ if chunk.content:
176
+ print(chunk.content, end="", flush=True)
177
+ ```
178
+
179
+ This approach can be useful when:
180
+ - You need more control over the HTTP requests
181
+ - You have compatibility issues with the official SDK
182
+ - You're using a model that follows the OpenAI API protocol but isn't fully compatible with the SDK
183
+
184
+ ### Tool Calls Support
185
+
186
+ ```python
187
+ from aworld.models.llm import get_llm_model, call_llm_model
188
+ import json
189
+
190
+ # Initialize model
191
+ model = get_llm_model(
192
+ llm_provider="openai",
193
+ model_name="gpt-4o",
194
+ api_key="your_api_key"
195
+ )
196
+
197
+ # Define tools
198
+ tools = [
199
+ {
200
+ "type": "function",
201
+ "function": {
202
+ "name": "get_weather",
203
+ "description": "Get the current weather in a given location",
204
+ "parameters": {
205
+ "type": "object",
206
+ "properties": {
207
+ "location": {
208
+ "type": "string",
209
+ "description": "The city and state, e.g. San Francisco, CA"
210
+ }
211
+ },
212
+ "required": ["location"]
213
+ }
214
+ }
215
+ }
216
+ ]
217
+
218
+ # Prepare messages
219
+ messages = [
220
+ {"role": "user", "content": "What's the weather like in San Francisco?"}
221
+ ]
222
+
223
+ # Call model with tools
224
+ response = call_llm_model(model, messages, tools=tools, tool_choice="auto")
225
+
226
+ # Check for tool calls
227
+ if response.tool_calls:
228
+ for tool_call in response.tool_calls:
229
+ print(f"Tool name: {tool_call.name}")
230
+ print(f"Arguments: {tool_call.arguments}")
231
+
232
+ # Handle tool call
233
+ if tool_call.name == "get_weather":
234
+ # Parse arguments
235
+ args = json.loads(tool_call.arguments)
236
+ location = args.get("location")
237
+
238
+ # Mock getting weather data
239
+ weather = "Sunny, 25°C"
240
+
241
+ # Add tool response to messages
242
+ messages.append(response.message) # Add assistant message
243
+ messages.append({
244
+ "role": "tool",
245
+ "tool_call_id": tool_call.id,
246
+ "name": tool_call.name,
247
+ "content": f"{{\"weather\": \"{weather}\"}}"
248
+ })
249
+
250
+ # Call model again
251
+ final_response = call_llm_model(model, messages)
252
+ print("\nFinal response:", final_response.content)
253
+ else:
254
+ print("\nResponse content:", response.content)
255
+ ```
256
+
257
+ ### Asynchronous Calls
258
+
259
+ ```python
260
+ import asyncio
261
+ from aworld.models.llm import get_llm_model
262
+
263
+ async def main():
264
+ # Initialize model
265
+ model = get_llm_model(
266
+ llm_provider="anthropic",
267
+ model_name="claude-3-5-sonnet-20241022",
268
+ temperature=0.0
269
+ )
270
+
271
+ # Prepare messages
272
+ messages = [
273
+ {"role": "user", "content": "Explain machine learning briefly."}
274
+ ]
275
+
276
+ # Async call
277
+ response = await model.acompletion(messages)
278
+ print(response.content)
279
+
280
+ # Run async function
281
+ asyncio.run(main())
282
+ ```
283
+
284
+ ### Streaming Responses
285
+
286
+ ```python
287
+ # Synchronous streaming
288
+ for chunk in model.stream_completion(messages):
289
+ print(chunk.content, end="", flush=True)
290
+
291
+ # Asynchronous streaming
292
+ async for chunk in model.astream_completion(messages):
293
+ print(chunk.content, end="", flush=True)
294
+ ```
295
+
296
+ ## ModelResponse Object
297
+
298
+ All responses are encapsulated in a unified `ModelResponse` object with these key attributes:
299
+
300
+ - `id`: Response ID
301
+ - `model`: Model name used
302
+ - `content`: Generated text content
303
+ - `tool_calls`: List of tool calls (if any)
304
+ - `usage`: Token usage statistics
305
+ - `error`: Error message (if any)
306
+ - `message`: Complete message object for subsequent API calls
307
+
308
+ Example:
309
+ ```python
310
+ response = call_llm_model(model, messages)
311
+ print(f"Content: {response.content}")
312
+ print(f"Model: {response.model}")
313
+ print(f"Total tokens: {response.usage['total_tokens']}")
314
+
315
+ # Get complete message for next call
316
+ messages.append(response.message)
317
+ ```
318
+
319
+ ## API Parameters
320
+
321
+ Essential parameters for model calls (see the sketch after this list):
322
+
323
+ - `messages`: List of message dictionaries with `role` and `content` keys
324
+ - `temperature`: Controls response randomness (0.0-1.0)
325
+ - `max_tokens`: Maximum tokens to generate
326
+ - `stop`: List of stopping sequences
327
+ - `tools`: List of tool definitions
328
+ - `tool_choice`: Tool choice strategy
329
+
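+ A minimal sketch combining these parameters in a single call (parameter support may vary by provider, and `tools` refers to the definition from the tool-calls example above):
+
+ ```python
+ response = call_llm_model(
+     model,
+     messages,
+     temperature=0.2,
+     max_tokens=256,
+     stop=["\n\n"],
+     tools=tools,          # optional
+     tool_choice="auto",   # optional
+ )
+ print(response.content)
+ ```
+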
330
+ ## Automatic Provider Detection
331
+
332
+ The system can automatically identify the provider based on the model name or API endpoint:
333
+
334
+ ```python
335
+ # Detect Anthropic based on model name
336
+ model = get_llm_model(model_name="claude-3-5-sonnet-20241022")
337
+
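+ # Hedged sketch: detection by API endpoint (the base_url below is illustrative;
+ # it is assumed to match one of the library's registered endpoint patterns)
+ model = get_llm_model(model_name="gpt-4o", base_url="https://api.openai.com/v1")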
338
+ ```
339
+
340
+ ## Creating Custom Providers
341
+
342
+ Implement your own provider by extending `LLMProviderBase`:
343
+
344
+ ```python
345
+ from aworld.models.llm import LLMProviderBase, register_llm_provider
346
+ from aworld.models.model_response import ModelResponse, ToolCall
347
+
348
+ class CustomProvider(LLMProviderBase):
349
+ def _init_provider(self):
350
+ # Initialize your API client
351
+ return {
352
+ "api_key": self.api_key,
353
+ "endpoint": self.base_url
354
+ }
355
+
356
+ def _init_async_provider(self):
357
+ # Initialize your asynchronous API client (optional)
358
+ # If not implemented, async methods will raise NotImplementedError
359
+ return None
360
+
361
+ def preprocess_messages(self, messages):
362
+ # Convert standard format to your API format
363
+ return messages
364
+
365
+ def postprocess_response(self, response):
366
+ # Convert API response to ModelResponse
367
+ return ModelResponse(
368
+ id="response_id",
369
+ model=self.model_name,
370
+ content=response.get("text", ""),
371
+ tool_calls=None # Parse ToolCall objects if supported
372
+ )
373
+
374
+ def completion(self, messages, temperature=0.0, **kwargs):
375
+ # Implement the actual API call
376
+ processed = self.preprocess_messages(messages)
377
+ # Call your API here...
378
+ response = {"text": "Response from custom provider"}
379
+ return self.postprocess_response(response)
380
+
381
+ async def acompletion(self, messages, temperature=0.0, **kwargs):
382
+ # Implement async API call
383
+ # Similar to completion but asynchronous
384
+ response = {"text": "Async response from custom provider"}
385
+ return self.postprocess_response(response)
386
+
387
+ # Register your provider
388
+ register_llm_provider("custom_provider", CustomProvider)
389
+
390
+ # Use it like any other provider
391
+ model = get_llm_model(llm_provider="custom_provider", model_name="custom-model")
392
+ ```
393
+
394
+ ## API Key Management
395
+
396
+ Keys are retrieved in this order:
397
+ 1. Direct `api_key` parameter
398
+ 2. Environment variable in `.env` file
399
+ 3. System environment variable
400
+
401
+ Example for OpenAI: the `api_key` parameter → `OPENAI_API_KEY` in `.env` → `OPENAI_API_KEY` in the system environment
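+
+ A minimal sketch of relying on environment-based key lookup (assuming `OPENAI_API_KEY` is honored exactly as described above):
+
+ ```python
+ import os
+
+ from aworld.models.llm import get_llm_model
+
+ os.environ["OPENAI_API_KEY"] = "your_api_key"  # or place OPENAI_API_KEY=... in a .env file
+
+ # No api_key argument needed; the key is resolved from the environment
+ model = get_llm_model(llm_provider="openai", model_name="gpt-4o")
+ ```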
aworld/models/ant_provider.py ADDED
@@ -0,0 +1,866 @@
1
+ import ast
2
+ import asyncio
3
+ import datetime
4
+ import html
5
+ import json
6
+ import os
7
+ import time
8
+
9
+ from typing import (
10
+ Any,
11
+ List,
12
+ Dict,
13
+ Generator,
14
+ AsyncGenerator,
15
+ )
16
+ from binascii import b2a_hex
17
+
18
+ from aworld.config.conf import ClientType
19
+ from aworld.core.llm_provider_base import LLMProviderBase
20
+ from aworld.models.llm_http_handler import LLMHTTPHandler
21
+ from aworld.models.model_response import ModelResponse, LLMResponseError, ToolCall
22
+ from aworld.logs.util import logger
23
+ from aworld.utils import import_package
24
+ from aworld.models.utils import usage_process
25
+
26
+ MODEL_NAMES = {
27
+ "anthropic": ["claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20240620", "claude-3-opus-20240229"],
28
+ "openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini", "gpt-4o-mini"],
29
+ }
30
+
31
+
32
+ # Custom JSON encoder to handle ToolCall and other special types
33
+ class CustomJSONEncoder(json.JSONEncoder):
34
+ """Custom JSON encoder to handle ToolCall objects and other special types."""
35
+
36
+ def default(self, obj):
37
+ # Handle objects with to_dict method
38
+ if hasattr(obj, 'to_dict') and callable(obj.to_dict):
39
+ return obj.to_dict()
40
+
41
+ # Handle objects with __dict__ attribute (most custom classes)
42
+ if hasattr(obj, '__dict__'):
43
+ return obj.__dict__
44
+
45
+ # Let the base class handle it (will raise TypeError if not serializable)
46
+ return super().default(obj)
47
+
48
+
49
+ class AntProvider(LLMProviderBase):
50
+ """Ant provider implementation.
51
+ """
52
+
53
+ def _init_provider(self):
54
+ """Initialize Ant provider.
55
+
56
+ Returns:
57
+ Ant provider instance.
58
+ """
59
+ import_package("Crypto", install_name="pycryptodome")
60
+
61
+ # Get API key
62
+ api_key = self.api_key
63
+
64
+ if not api_key:
65
+ env_var = "ANT_API_KEY"
66
+ api_key = os.getenv(env_var, "")
67
+ self.api_key = api_key
68
+ if not api_key:
69
+ raise ValueError(
70
+ f"ANT API key not found, please set {env_var} environment variable or provide it in the parameters")
71
+
72
+ if api_key and api_key.startswith("ak_info:"):
73
+ ak_info_str = api_key[len("ak_info:"):]
74
+ try:
75
+ ak_info = json.loads(ak_info_str)
76
+ for key, value in ak_info.items():
77
+ os.environ[key] = value
78
+ if key == "ANT_API_KEY":
79
+ api_key = value
80
+ self.api_key = api_key
81
+ except Exception as e:
82
+ logger.warn(f"Invalid ANT API key startswith ak_info: {api_key}")
83
+
84
+ self.stream_api_key = os.getenv("ANT_STREAM_API_KEY", "")
85
+
86
+ base_url = self.base_url
87
+ if not base_url:
88
+ base_url = os.getenv("ANT_ENDPOINT", "https://zdfmng.alipay.com")
89
+ self.base_url = base_url
90
+
91
+ self.aes_key = os.getenv("ANT_AES_KEY", "")
92
+
93
+ self.is_http_provider = True
94
+ self.kwargs["client_type"] = ClientType.HTTP
95
+ logger.info(f"Using HTTP provider for Ant")
96
+ self.http_provider = LLMHTTPHandler(
97
+ base_url=base_url,
98
+ api_key=api_key,
99
+ model_name=self.model_name,
100
+ )
102
+ return self.http_provider
103
+
104
+ def _init_async_provider(self):
105
+ """Initialize async Ant provider.
106
+
107
+ Returns:
108
+ Async Ant provider instance.
109
+ """
110
+ # Get API key
111
+ if not self.provider:
112
+ provider = self._init_provider()
113
+ return provider
114
+
115
+ @classmethod
116
+ def supported_models(cls) -> list[str]:
117
+ return [""]
118
+
119
+ def _aes_encrypt(self, data, key):
120
+ """AES encryption function. If data is not a multiple of 16 [encrypted data must be a multiple of 16!], pad it to a multiple of 16.
121
+
122
+ Args:
123
+ key: Encryption key
124
+ data: Data to encrypt
125
+
126
+ Returns:
127
+ Encrypted data
128
+ """
129
+ from Crypto.Cipher import AES
130
+
131
+ iv = "1234567890123456"
132
+ cipher = AES.new(key.encode('utf-8'), AES.MODE_CBC, iv.encode('utf-8'))
133
+ block_size = AES.block_size
134
+
135
+ # Check if data is a multiple of 16, if not, pad with b'\0'
136
+ if len(data) % block_size != 0:
137
+ add = block_size - (len(data) % block_size)
138
+ else:
139
+ add = 0
140
+ data = data.encode('utf-8') + b'\0' * add
141
+ encrypted = cipher.encrypt(data)
142
+ result = b2a_hex(encrypted)
143
+ return result.decode('utf-8')
144
+
145
+ def _build_openai_params(self,
146
+ messages: List[Dict[str, str]],
147
+ temperature: float = 0.0,
148
+ max_tokens: int = None,
149
+ stop: List[str] = None,
150
+ **kwargs) -> Dict[str, Any]:
151
+ openai_params = {
152
+ "model": kwargs.get("model_name", self.model_name or ""),
153
+ "messages": messages,
154
+ "temperature": temperature,
155
+ "max_tokens": max_tokens,
156
+ "stop": stop
157
+ }
158
+
159
+ supported_params = [
160
+ "frequency_penalty", "logit_bias", "logprobs", "top_logprobs",
161
+ "presence_penalty", "response_format", "seed", "stream", "top_p",
162
+ "user", "function_call", "functions", "tools", "tool_choice"
163
+ ]
164
+
165
+ for param in supported_params:
166
+ if param in kwargs:
167
+ openai_params[param] = kwargs[param]
168
+
169
+ return openai_params
170
+
171
+ def _build_claude_params(self,
172
+ messages: List[Dict[str, str]],
173
+ temperature: float = 0.0,
174
+ max_tokens: int = None,
175
+ stop: List[str] = None,
176
+ **kwargs) -> Dict[str, Any]:
177
+ claude_params = {
178
+ "model": kwargs.get("model_name", self.model_name or ""),
179
+ "messages": messages,
180
+ "temperature": temperature,
181
+ "max_tokens": max_tokens,
182
+ "stop": stop
183
+ }
184
+
185
+ supported_params = [
186
+ "top_p", "top_k", "reasoning_effort", "tools", "tool_choice"
187
+ ]
188
+
189
+ for param in supported_params:
190
+ if param in kwargs:
191
+ claude_params[param] = kwargs[param]
192
+
193
+ return claude_params
194
+
195
+ def _get_visit_info(self):
196
+ visit_info = {
197
+ "visitDomain": self.kwargs.get("ant_visit_domain") or os.getenv("ANT_VISIT_DOMAIN", "BU_general"),
198
+ "visitBiz": self.kwargs.get("ant_visit_biz") or os.getenv("ANT_VISIT_BIZ", ""),
199
+ "visitBizLine": self.kwargs.get("ant_visit_biz_line") or os.getenv("ANT_VISIT_BIZ_LINE", "")
200
+ }
201
+ if not visit_info["visitBiz"] or not visit_info["visitBizLine"]:
202
+ return None
203
+ return visit_info
204
+
205
+ def _get_service_param(self,
206
+ message_key: str,
207
+ output_type: str = "request",
208
+ messages: List[Dict[str, str]] = None,
209
+ temperature: float = 0.0,
210
+ max_tokens: int = None,
211
+ stop: List[str] = None,
212
+ **kwargs
213
+ ) -> Dict[str, Any]:
214
+ """Get service name from model name.
215
+ Returns:
216
+ Service name.
217
+ """
218
+ if messages:
219
+ for message in messages:
220
+ if message["role"] == "assistant" and "tool_calls" in message and message["tool_calls"]:
221
+ if message["content"] is None: message["content"] = ""
222
+ processed_tool_calls = []
223
+ for tool_call in message["tool_calls"]:
224
+ if isinstance(tool_call, dict):
225
+ processed_tool_calls.append(tool_call)
226
+ elif isinstance(tool_call, ToolCall):
227
+ processed_tool_calls.append(tool_call.to_dict())
228
+ message["tool_calls"] = processed_tool_calls
229
+ query_conditions = {
230
+ "messageKey": message_key,
231
+ }
232
+ param = {"cacheInterval": -1, }
233
+ visit_info = self._get_visit_info()
234
+ if not visit_info:
235
+ raise LLMResponseError(
236
+ f"AntProvider#Invalid visit_info, please set ANT_VISIT_BIZ and ANT_VISIT_BIZ_LINE environment variable or provide it in the parameters",
237
+ self.model_name or "unknown"
238
+ )
239
+ param.update(visit_info)
240
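+ # Route by model/request type: Claude models call the claude chat-completions service,
+ # "pull" requests query the result of a previously submitted call, and anything else
+ # submits an async OpenAI-style completion request that is polled for later.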
+ if self.model_name.startswith("claude"):
241
+ query_conditions.update(self._build_claude_params(messages, temperature, max_tokens, stop, **kwargs))
242
+ param.update({
243
+ "serviceName": "amazon_claude_chat_completions_dataview",
244
+ "queryConditions": query_conditions,
245
+ })
246
+ elif output_type == "pull":
247
+ param.update({
248
+ "serviceName": "chatgpt_response_query_dataview",
249
+ "queryConditions": query_conditions
250
+ })
251
+ else:
252
+ query_conditions = {
253
+ "model": self.model_name,
254
+ "n": "1",
255
+ "api_key": self.api_key,
256
+ "messageKey": message_key,
257
+ "outputType": "PULL",
258
+ "messages": messages,
259
+ }
260
+ query_conditions.update(self._build_openai_params(messages, temperature, max_tokens, stop, **kwargs))
261
+ param.update({
262
+ "serviceName": "asyn_chatgpt_prompts_completions_query_dataview",
263
+ "queryConditions": query_conditions,
264
+ })
265
+ return param
266
+
267
+ def _gen_message_key(self):
268
+ def _timestamp():
269
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
270
+ return timestamp
271
+
272
+ timestamp = _timestamp()
273
+ message_key = "llm_call_%s" % (timestamp)
274
+ return message_key
275
+
276
+ def _build_request_data(self, param: Dict[str, Any]):
277
+ param_data = json.dumps(param)
278
+ encrypted_param_data = self._aes_encrypt(param_data, self.aes_key)
279
+ post_data = {"encryptedParam": encrypted_param_data}
280
+ return post_data
281
+
282
+ def _build_chat_query_request_data(self,
283
+ message_key: str,
284
+ messages: List[Dict[str, str]],
285
+ temperature: float = 0.0,
286
+ max_tokens: int = None,
287
+ stop: List[str] = None,
288
+ **kwargs):
289
+ param = self._get_service_param(message_key, "request", messages, temperature, max_tokens, stop, **kwargs)
290
+ query_data = self._build_request_data(param)
291
+ return query_data
292
+
293
+ def _post_chat_query_request(self,
294
+ messages: List[Dict[str, str]],
295
+ temperature: float = 0.0,
296
+ max_tokens: int = None,
297
+ stop: List[str] = None,
298
+ **kwargs):
299
+ message_key = self._gen_message_key()
300
+ post_data = self._build_chat_query_request_data(message_key,
301
+ messages,
302
+ model_name=self.model_name,
303
+ temperature=temperature,
304
+ max_tokens=max_tokens,
305
+ stop=stop,
306
+ **kwargs)
307
+ response = self.http_provider.sync_call(post_data, endpoint="commonQuery/queryData")
308
+ return message_key, response
309
+
310
+ def _valid_chat_result(self, body):
311
+ if "data" not in body or not body["data"]:
312
+ return False
313
+ if "values" not in body["data"] or not body["data"]["values"]:
314
+ return False
315
+ if "response" not in body["data"]["values"] and "data" not in body["data"]["values"]:
316
+ return False
317
+ return True
318
+
319
+ def _build_chat_pull_request_data(self, message_key):
320
+ param = self._get_service_param(message_key, "pull")
321
+
322
+ pull_data = self._build_request_data(param)
323
+ return pull_data
324
+
325
+ def _pull_chat_result(self, message_key, response: Dict[str, Any], timeout):
326
+ if self.model_name.startswith("claude"):
327
+ if self._valid_chat_result(response):
328
+ x = response["data"]["values"]["data"]
329
+ ast_str = ast.literal_eval("'" + x + "'")
330
+ result = html.unescape(ast_str)
331
+ data = json.loads(result)
332
+ return data
333
+ else:
334
+ raise LLMResponseError(
335
+ f"Invalid response from Ant API, response: {response}",
336
+ self.model_name or "unknown"
337
+ )
338
+
339
+ post_data = self._build_chat_pull_request_data(message_key)
340
+ url = 'commonQuery/queryData'
341
+ headers = {
342
+ 'Content-Type': 'application/json'
343
+ }
344
+
345
+ # Start polling until valid result or timeout
346
+ start_time = time.time()
347
+ elapsed_time = 0
348
+
349
+ while elapsed_time < timeout:
350
+ response = self.http_provider.sync_call(post_data, endpoint=url, headers=headers)
351
+
352
+ logger.debug(f"Poll attempt at {elapsed_time}s, response: {response}")
353
+
354
+ # Check if valid result is received
355
+ if self._valid_chat_result(response):
356
+ x = response["data"]["values"]["response"]
357
+ ast_str = ast.literal_eval("'" + x + "'")
358
+ result = html.unescape(ast_str)
359
+ data = json.loads(result)
360
+ return data
361
+ elif (not response.get("success")) or ("data" in response and response["data"]):
362
+ err_code = response.get("data", {}).get("errorCode", "")
363
+ err_msg = response.get("data", {}).get("errorMessage", "")
364
+ if err_code or err_msg:
365
+ raise LLMResponseError(
366
+ f"Request failed: {response}",
367
+ self.model_name or "unknown"
368
+ )
369
+
370
+ # If no result, wait 1 second and query again
371
+ time.sleep(1)
372
+ elapsed_time = time.time() - start_time
373
+ logger.debug(f"Polling... Elapsed time: {elapsed_time:.1f}s")
374
+
375
+ # Timeout handling
376
+ raise LLMResponseError(
377
+ f"Timeout after {timeout} seconds waiting for response from Ant API",
378
+ self.model_name or "unknown"
379
+ )
380
+
381
+ async def _async_pull_chat_result(self, message_key, response: Dict[str, Any], timeout):
382
+ if self.model_name.startswith("claude"):
383
+ if self._valid_chat_result(response):
384
+ x = response["data"]["values"]["data"]
385
+ ast_str = ast.literal_eval("'" + x + "'")
386
+ result = html.unescape(ast_str)
387
+ data = json.loads(result)
388
+ return data
389
+ elif (not response.get("success")) or ("data" in response and response["data"]):
390
+ err_code = response.get("data", {}).get("errorCode", "")
391
+ err_msg = response.get("data", {}).get("errorMessage", "")
392
+ if err_code or err_msg:
393
+ raise LLMResponseError(
394
+ f"Request failed: {response}",
395
+ self.model_name or "unknown"
396
+ )
397
+
398
+ post_data = self._build_chat_pull_request_data(message_key)
399
+ url = 'commonQuery/queryData'
400
+ headers = {
401
+ 'Content-Type': 'application/json'
402
+ }
403
+
404
+ # Start polling until valid result or timeout
405
+ start_time = time.time()
406
+ elapsed_time = 0
407
+
408
+ while elapsed_time < timeout:
409
+ response = await self.http_provider.async_call(post_data, endpoint=url, headers=headers)
410
+
411
+ logger.debug(f"Poll attempt at {elapsed_time}s, response: {response}")
412
+
413
+ # Check if valid result is received
414
+ if self._valid_chat_result(response):
415
+ x = response["data"]["values"]["response"]
416
+ ast_str = ast.literal_eval("'" + x + "'")
417
+ result = html.unescape(ast_str)
418
+ data = json.loads(result)
419
+ return data
420
+ elif (not response.get("success")) or ("data" in response and response["data"]):
421
+ err_code = response.get("data", {}).get("errorCode", "")
422
+ err_msg = response.get("data", {}).get("errorMessage", "")
423
+ if err_code or err_msg:
424
+ raise LLMResponseError(
425
+ f"Request failed: {response}",
426
+ self.model_name or "unknown"
427
+ )
428
+
429
+ # If no result, wait 1 second and query again
430
+ await asyncio.sleep(1)
431
+ elapsed_time = time.time() - start_time
432
+ logger.debug(f"Polling... Elapsed time: {elapsed_time:.1f}s")
433
+
434
+ # Timeout handling
435
+ raise LLMResponseError(
436
+ f"Timeout after {timeout} seconds waiting for response from Ant API",
437
+ self.model_name or "unknown"
438
+ )
439
+
440
+ def _convert_completion_message(self, message: Dict[str, Any], is_finished: bool = False) -> ModelResponse:
441
+ """Convert Ant completion message to OpenAI format.
442
+
443
+ Args:
444
+ message: Ant completion message.
445
+
446
+ Returns:
447
+ ModelResponse object.
448
+ """
449
+ # Generate unique ID
450
+ response_id = f"ant-{hash(str(message)) & 0xffffffff:08x}"
451
+
452
+ # Get content
453
+ content = message.get("completion", "")
454
+
455
+ # Create message object
456
+ message_dict = {
457
+ "role": "assistant",
458
+ "content": content,
459
+ "is_chunk": True
460
+ }
461
+
462
+ # Keep original contextId and sessionId
463
+ if "contextId" in message:
464
+ message_dict["contextId"] = message["contextId"]
465
+ if "sessionId" in message:
466
+ message_dict["sessionId"] = message["sessionId"]
467
+
468
+ usage = {
469
+ "completion_tokens": message.get("completionToken", 0),
470
+ "prompt_tokens": message.get("promptTokens", 0),
471
+ "total_tokens": message.get("completionToken", 0) + message.get("promptTokens", 0)
472
+ }
473
+
474
+ # process tool calls
475
+ tool_calls = message.get("toolCalls", [])
476
+ for tool_call in tool_calls:
477
+ index = tool_call.get("index", 0)
478
+ name = tool_call.get("function", {}).get("name")
479
+ arguments = tool_call.get("function", {}).get("arguments")
480
+ if index >= len(self.stream_tool_buffer):
481
+ self.stream_tool_buffer.append({
482
+ "id": tool_call.get("id"),
483
+ "type": "function",
484
+ "function": {
485
+ "name": name,
486
+ "arguments": arguments
487
+ }
488
+ })
489
+ else:
490
+ self.stream_tool_buffer[index]["function"]["arguments"] += arguments
491
+
492
+ if is_finished and self.stream_tool_buffer:
493
+ message_dict["tool_calls"] = self.stream_tool_buffer.copy()
494
+ processed_tool_calls = []
495
+ for tool_call in self.stream_tool_buffer:
496
+ processed_tool_calls.append(ToolCall.from_dict(tool_call))
497
+ tool_resp = ModelResponse(
498
+ id=response_id,
499
+ model=self.model_name or "ant",
500
+ content=content,
501
+ tool_calls=processed_tool_calls,
502
+ usage=usage,
503
+ raw_response=message,
504
+ message=message_dict
505
+ )
506
+ self.stream_tool_buffer = []
507
+ return tool_resp
508
+
509
+ # Build and return ModelResponse object directly
510
+ return ModelResponse(
511
+ id=response_id,
512
+ model=self.model_name or "ant",
513
+ content=content,
514
+ tool_calls=None, # TODO: add tool calls
515
+ usage=usage,
516
+ raw_response=message,
517
+ message=message_dict
518
+ )
519
+
520
+ def preprocess_stream_call_message(self, messages: List[Dict[str, str]], ext_params: Dict[str, Any]) -> Dict[
521
+ str, str]:
522
+ """Preprocess messages, use Ant format directly.
523
+
524
+ Args:
525
+ messages: Ant format message list.
526
+
527
+ Returns:
528
+ Processed message list.
529
+ """
530
+ param = {
531
+ "messages": messages,
532
+ "sessionId": "TkQUldjzOgYSKyTrpor3TA==",
533
+ "model": self.model_name,
534
+ "needMemory": False,
535
+ "stream": True,
536
+ "contextId": "contextId_34555fd2d246447fa55a1a259445a427",
537
+ "platform": "AWorld"
538
+ }
539
+ for k in ext_params.keys():
540
+ if k not in param:
541
+ param[k] = ext_params[k]
542
+ return param
543
+
544
+ def postprocess_response(self, response: Any) -> ModelResponse:
545
+ """Process Ant response.
546
+
547
+ Args:
548
+ response: Ant response object.
549
+
550
+ Returns:
551
+ ModelResponse object.
552
+
553
+ Raises:
554
+ LLMResponseError: When LLM response error occurs.
555
+ """
556
+ if ((not isinstance(response, dict) and (not hasattr(response, 'choices') or not response.choices))
557
+ or (isinstance(response, dict) and not response.get("choices"))):
558
+ error_msg = ""
559
+ if hasattr(response, 'error') and response.error and isinstance(response.error, dict):
560
+ error_msg = response.error.get('message', '')
561
+ elif hasattr(response, 'msg'):
562
+ error_msg = response.msg
563
+
564
+ raise LLMResponseError(
565
+ error_msg if error_msg else "Unknown error",
566
+ self.model_name or "unknown",
567
+ response
568
+ )
569
+
570
+ return ModelResponse.from_openai_response(response)
571
+
572
+ def postprocess_stream_response(self, chunk: Any) -> ModelResponse:
573
+ """Process Ant stream response chunk.
574
+
575
+ Args:
576
+ chunk: Ant response chunk.
577
+
578
+ Returns:
579
+ ModelResponse object.
580
+
581
+ Raises:
582
+ LLMResponseError: When LLM response error occurs.
583
+ """
584
+ # Check if chunk contains error
585
+ if hasattr(chunk, 'error') or (isinstance(chunk, dict) and chunk.get('error')):
586
+ error_msg = chunk.error if hasattr(chunk, 'error') else chunk.get('error', 'Unknown error')
587
+ raise LLMResponseError(
588
+ error_msg,
589
+ self.model_name or "unknown",
590
+ chunk
591
+ )
592
+
593
+ if isinstance(chunk, dict) and ('completion' in chunk):
594
+ return self._convert_completion_message(chunk)
595
+
596
+ # If chunk is already in OpenAI format, use standard processing method
597
+ return ModelResponse.from_openai_stream_chunk(chunk)
598
+
599
+ def completion(self,
600
+ messages: List[Dict[str, str]],
601
+ temperature: float = 0.0,
602
+ max_tokens: int = None,
603
+ stop: List[str] = None,
604
+ **kwargs) -> ModelResponse:
605
+ """Synchronously call Ant to generate response.
606
+
607
+ Args:
608
+ messages: Message list.
609
+ temperature: Temperature parameter.
610
+ max_tokens: Maximum number of tokens to generate.
611
+ stop: List of stop sequences.
612
+ **kwargs: Other parameters.
613
+
614
+ Returns:
615
+ ModelResponse object.
616
+
617
+ Raises:
618
+ LLMResponseError: When LLM response error occurs.
619
+ """
620
+ if not self.provider:
621
+ raise RuntimeError(
622
+ "Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")
623
+
624
+ try:
625
+ start_time = time.time()
626
+ message_key, response = self._post_chat_query_request(messages, temperature, max_tokens, stop, **kwargs)
627
+ timeout = kwargs.get("response_timeout", self.kwargs.get("timeout", 180))
628
+ result = self._pull_chat_result(message_key, response, timeout)
629
+ logger.info(f"completion cost time: {time.time() - start_time}s.")
630
+
631
+ resp = self.postprocess_response(result)
632
+ usage_process(resp.usage)
633
+ return resp
634
+ except Exception as e:
635
+ if isinstance(e, LLMResponseError):
636
+ raise e
637
+ logger.warn(f"Error in Ant completion: {e}")
638
+ raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))
639
+
640
+ async def acompletion(self,
641
+ messages: List[Dict[str, str]],
642
+ temperature: float = 0.0,
643
+ max_tokens: int = None,
644
+ stop: List[str] = None,
645
+ **kwargs) -> ModelResponse:
646
+ """Asynchronously call Ant to generate response.
647
+
648
+ Args:
649
+ messages: Message list.
650
+ temperature: Temperature parameter.
651
+ max_tokens: Maximum number of tokens to generate.
652
+ stop: List of stop sequences.
653
+ **kwargs: Other parameters.
654
+
655
+ Returns:
656
+ ModelResponse object.
657
+
658
+ Raises:
659
+ LLMResponseError: When LLM response error occurs.
660
+ """
661
+ if not self.async_provider:
662
+ self._init_async_provider()
663
+
664
+ start_time = time.time()
665
+ try:
666
+ message_key, response = self._post_chat_query_request(messages, temperature, max_tokens, stop, **kwargs)
667
+ timeout = kwargs.get("response_timeout", self.kwargs.get("timeout", 180))
668
+ result = await self._async_pull_chat_result(message_key, response, timeout)
669
+ logger.info(f"completion cost time: {time.time() - start_time}s.")
670
+
671
+ resp = self.postprocess_response(result)
672
+ usage_process(resp.usage)
673
+ return resp
674
+
675
+ except Exception as e:
676
+ if isinstance(e, LLMResponseError):
677
+ raise e
678
+ logger.warn(f"Error in async Ant completion: {e}")
679
+ raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))
680
+
681
+ def stream_completion(self,
682
+ messages: List[Dict[str, str]],
683
+ temperature: float = 0.0,
684
+ max_tokens: int = None,
685
+ stop: List[str] = None,
686
+ **kwargs) -> Generator[ModelResponse, None, None]:
687
+ """Synchronously call Ant to generate streaming response.
688
+
689
+ Args:
690
+ messages: Message list.
691
+ temperature: Temperature parameter.
692
+ max_tokens: Maximum number of tokens to generate.
693
+ stop: List of stop sequences.
694
+ **kwargs: Other parameters.
695
+
696
+ Returns:
697
+ Generator yielding ModelResponse chunks.
698
+
699
+ Raises:
700
+ LLMResponseError: When LLM response error occurs.
701
+ """
702
+ if not self.provider:
703
+ raise RuntimeError(
704
+ "Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")
705
+
706
+ start_time = time.time()
707
+ # Generate message_key
708
+ timestamp = int(time.time())
709
+ self.message_key = f"llm_call_{timestamp}"
710
+ message_key_literal = self.message_key # Ensure it's a direct string literal
711
+ self.aes_key = kwargs.get("aes_key", self.aes_key)
712
+
713
+ # Add streaming parameter
714
+ kwargs["stream"] = True
715
+ processed_messages = self.preprocess_stream_call_message(messages,
716
+ self._build_openai_params(messages, temperature, max_tokens,
717
+ stop, **kwargs))
718
+ if not processed_messages:
719
+ raise LLMResponseError("Failed to get post data", self.model_name or "unknown")
720
+
721
+ usage = {
722
+ "prompt_tokens": 0,
723
+ "completion_tokens": 0,
724
+ "total_tokens": 0
725
+ }
726
+
727
+ try:
728
+ # Send request
729
+ # response = self.http_provider.sync_call(processed_messages[0], endpoint="commonQuery/queryData")
730
+ headers = {
731
+ "Content-Type": "application/json",
732
+ "X_ACCESS_KEY": self.stream_api_key
733
+ }
734
+ response_stream = self.http_provider.sync_stream_call(processed_messages, endpoint="chat/completions",
735
+ headers=headers)
736
+ if response_stream:
737
+ for chunk in response_stream:
738
+ if not chunk:
739
+ continue
740
+
741
+ # Process special markers
742
+ if isinstance(chunk, dict) and "status" in chunk:
743
+ if chunk["status"] == "done":
744
+ # Stream completion marker, can choose to end
745
+ logger.info("Received [DONE] marker, stream completed")
746
+ yield self._convert_completion_message(chunk, is_finished=True)
747
+ yield ModelResponse.from_special_marker("done", self.model_name, chunk)
748
+ break
749
+ elif chunk["status"] == "revoke":
750
+ # Revoke marker, need to notify the frontend to revoke the displayed content
751
+ logger.info("Received [REVOKE] marker, content should be revoked")
752
+ yield ModelResponse.from_special_marker("revoke", self.model_name, chunk)
753
+ continue
754
+ elif chunk["status"] == "fail":
755
+ # Fail marker
756
+ logger.error("Received [FAIL] marker, request failed")
757
+ raise LLMResponseError("Request failed", self.model_name or "unknown")
758
+ elif chunk["status"] == "cancel":
759
+ # Request was cancelled
760
+ logger.warning("Received [CANCEL] marker, stream was cancelled")
761
+ raise LLMResponseError("Stream was cancelled", self.model_name or "unknown")
762
+ continue
763
+
764
+ # Process normal response chunks
765
+ resp = self.postprocess_stream_response(chunk)
766
+ self._accumulate_chunk_usage(usage, resp.usage)
767
+ yield resp
768
+ usage_process(usage)
769
+
770
+ logger.info(f"stream_completion cost time: {time.time() - start_time}s.")
771
+ except Exception as e:
772
+ if isinstance(e, LLMResponseError):
773
+ raise e
774
+ logger.error(f"Error in Ant stream completion: {e}")
775
+ raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))
776
+
777
+ async def astream_completion(self,
778
+ messages: List[Dict[str, str]],
779
+ temperature: float = 0.0,
780
+ max_tokens: int = None,
781
+ stop: List[str] = None,
782
+ **kwargs) -> AsyncGenerator[ModelResponse, None]:
783
+ """Asynchronously call Ant to generate streaming response.
784
+
785
+ Args:
786
+ messages: Message list.
787
+ temperature: Temperature parameter.
788
+ max_tokens: Maximum number of tokens to generate.
789
+ stop: List of stop sequences.
790
+ **kwargs: Other parameters.
791
+
792
+ Returns:
793
+ AsyncGenerator yielding ModelResponse chunks.
794
+
795
+ Raises:
796
+ LLMResponseError: When LLM response error occurs.
797
+ """
798
+ if not self.async_provider:
799
+ self._init_async_provider()
800
+
801
+ start_time = time.time()
802
+ # Generate message_key
803
+ timestamp = int(time.time())
804
+ self.message_key = f"llm_call_{timestamp}"
805
+ message_key_literal = self.message_key # Ensure it's a direct string literal
806
+ self.aes_key = kwargs.get("aes_key", self.aes_key)
807
+
808
+ # Add streaming parameter
809
+ kwargs["stream"] = True
810
+ processed_messages = self.preprocess_stream_call_message(messages,
811
+ self._build_openai_params(messages, temperature, max_tokens,
812
+ stop, **kwargs))
813
+ if not processed_messages:
814
+ raise LLMResponseError("Failed to get post data", self.model_name or "unknown")
815
+
816
+ usage = {
817
+ "prompt_tokens": 0,
818
+ "completion_tokens": 0,
819
+ "total_tokens": 0
820
+ }
821
+ try:
822
+ headers = {
823
+ "Content-Type": "application/json",
824
+ "X_ACCESS_KEY": self.stream_api_key
825
+ }
826
+ logger.info(f"astream_completion request data: {processed_messages}")
827
+
828
+ async for chunk in self.http_provider.async_stream_call(processed_messages, endpoint="chat/completions",
829
+ headers=headers):
830
+ if not chunk:
831
+ continue
832
+
833
+ # Process special markers
834
+ if isinstance(chunk, dict) and "status" in chunk:
835
+ if chunk["status"] == "done":
836
+ # Stream completion marker, can choose to end
837
+ logger.info("Received [DONE] marker, stream completed")
838
+ yield ModelResponse.from_special_marker("done", self.model_name, chunk)
839
+ break
840
+ elif chunk["status"] == "revoke":
841
+ # Revoke marker, need to notify the frontend to revoke the displayed content
842
+ logger.info("Received [REVOKE] marker, content should be revoked")
843
+ yield ModelResponse.from_special_marker("revoke", self.model_name, chunk)
844
+ continue
845
+ elif chunk["status"] == "fail":
846
+ # Fail marker
847
+ logger.error("Received [FAIL] marker, request failed")
848
+ raise LLMResponseError("Request failed", self.model_name or "unknown")
849
+ elif chunk["status"] == "cancel":
850
+ # Request was cancelled
851
+ logger.warning("Received [CANCEL] marker, stream was cancelled")
852
+ raise LLMResponseError("Stream was cancelled", self.model_name or "unknown")
853
+ continue
854
+
855
+ # Process normal response chunks
856
+ resp = self.postprocess_stream_response(chunk)
857
+ self._accumulate_chunk_usage(usage, resp.usage)
858
+ yield resp
859
+ usage_process(usage)
860
+
861
+ logger.info(f"astream_completion cost time: {time.time() - start_time}s.")
862
+ except Exception as e:
863
+ if isinstance(e, LLMResponseError):
864
+ raise e
865
+ logger.warn(f"Error in async Ant stream completion: {e}")
866
+ raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))
aworld/models/anthropic_provider.py ADDED
@@ -0,0 +1,333 @@
1
+ import os
2
+ from typing import Any, Dict, List, Generator, AsyncGenerator
3
+
4
+ from aworld.utils import import_package
5
+ from aworld.logs.util import logger
6
+ from aworld.core.llm_provider_base import LLMProviderBase
7
+ from aworld.models.model_response import ModelResponse, LLMResponseError
8
+
9
+
10
+ class AnthropicProvider(LLMProviderBase):
11
+ """Anthropic provider implementation.
12
+ """
13
+
14
+ def __init__(self,
15
+ api_key: str = None,
16
+ base_url: str = None,
17
+ model_name: str = None,
18
+ sync_enabled: bool = None,
19
+ async_enabled: bool = None,
20
+ **kwargs):
21
+ super().__init__(api_key, base_url, model_name, sync_enabled, async_enabled, **kwargs)
22
+ import_package("anthropic")
23
+
24
+ def _init_provider(self):
25
+ """Initialize Anthropic provider.
26
+
27
+ Returns:
28
+ Anthropic provider instance.
29
+ """
30
+ from anthropic import Anthropic
31
+
32
+ # Get API key
33
+ api_key = self.api_key
34
+ if not api_key:
35
+ env_var = "ANTHROPIC_API_KEY"
36
+ api_key = os.getenv(env_var, "")
37
+ if not api_key:
38
+ raise ValueError(
39
+ f"Anthropic API key not found, please set {env_var} environment variable or provide it in the parameters")
40
+
41
+ return Anthropic(
42
+ api_key=api_key,
43
+ base_url=self.base_url
44
+ )
45
+
46
+ def _init_async_provider(self):
47
+ """Initialize async Anthropic provider.
48
+
49
+ Returns:
50
+ Async Anthropic provider instance.
51
+ """
52
+ from anthropic import Anthropic, AsyncAnthropic
53
+
54
+ # Get API key
55
+ api_key = self.api_key
56
+ if not api_key:
57
+ env_var = "ANTHROPIC_API_KEY"
58
+ api_key = os.getenv(env_var, "")
59
+ if not api_key:
60
+ raise ValueError(
61
+ f"Anthropic API key not found, please set {env_var} environment variable or provide it in the parameters")
62
+
63
+ return AsyncAnthropic(
64
+ api_key=api_key,
65
+ base_url=self.base_url
66
+ )
67
+
68
+ @classmethod
69
+ def supported_models(cls) -> list[str]:
70
+ return [r"claude-3-.*"]
71
+
72
+ def preprocess_messages(self, messages: List[Dict[str, str]]) -> Dict[str, Any]:
73
+ """Preprocess messages, convert OpenAI format to Anthropic format.
74
+
75
+ Args:
76
+ messages: OpenAI format message list.
77
+
78
+ Returns:
79
+ Converted message dictionary, containing messages and system fields.
80
+ """
81
+ anthropic_messages = []
82
+ system_content = None
83
+
84
+ for msg in messages:
85
+ role = msg.get("role", "")
86
+ content = msg.get("content", "")
87
+
88
+ if role == "system":
89
+ system_content = content
90
+ elif role == "user":
91
+ anthropic_messages.append({"role": "user", "content": content})
92
+ elif role == "assistant":
93
+ anthropic_messages.append({"role": "assistant", "content": content})
94
+
95
+ return {
96
+ "messages": anthropic_messages,
97
+ "system": system_content
98
+ }
99
+
100
+ def postprocess_response(self, response: Any) -> ModelResponse:
101
+ """Process Anthropic response to unified ModelResponse.
102
+
103
+ Args:
104
+ response: Anthropic response object.
105
+
106
+ Returns:
107
+ ModelResponse object.
108
+
109
+ Raises:
110
+ LLMResponseError: When LLM response error occurs.
111
+ """
112
+ # Check if response is empty or contains error
113
+ if not response or (isinstance(response, dict) and response.get('error')):
114
+ error_msg = response.get('error', 'Unknown error') if isinstance(response, dict) else 'Empty response'
115
+ raise LLMResponseError(error_msg, self.model_name or "claude", response)
116
+
117
+ return ModelResponse.from_anthropic_response(response)
118
+
119
+ def postprocess_stream_response(self, chunk: Any) -> ModelResponse:
120
+ """Process Anthropic streaming response chunk.
121
+
122
+ Args:
123
+ chunk: Anthropic response chunk.
124
+
125
+ Returns:
126
+ ModelResponse object.
127
+
128
+ Raises:
129
+ LLMResponseError: When LLM response error occurs.
130
+ """
131
+ # Check if chunk is empty or contains error
132
+ if not chunk or (isinstance(chunk, dict) and chunk.get('error')):
133
+ error_msg = chunk.get('error', 'Unknown error') if isinstance(chunk, dict) else 'Empty response'
134
+ raise LLMResponseError(error_msg, self.model_name or "claude", chunk)
135
+
136
+ return ModelResponse.from_anthropic_stream_chunk(chunk)
137
+
138
+ def completion(self,
139
+ messages: List[Dict[str, str]],
140
+ temperature: float = 0.0,
141
+ max_tokens: int = None,
142
+ stop: List[str] = None,
143
+ **kwargs) -> ModelResponse:
144
+ """Synchronously call Anthropic to generate response.
145
+
146
+ Args:
147
+ messages: Message list.
148
+ temperature: Temperature parameter.
149
+ max_tokens: Maximum number of tokens to generate.
150
+ stop: List of stop sequences.
151
+ **kwargs: Other parameters.
152
+
153
+ Returns:
154
+ ModelResponse object.
155
+ """
156
+ if not self.provider:
157
+ raise RuntimeError(
158
+ "Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")
159
+
160
+ try:
161
+ processed_data = self.preprocess_messages(messages)
162
+ processed_messages = processed_data["messages"]
163
+ system_content = processed_data["system"]
164
+ anthropic_params = self.get_anthropic_params(processed_messages, system_content, temperature, max_tokens,
165
+ stop, **kwargs)
166
+ response = self.provider.messages.create(**anthropic_params)
167
+
168
+ return self.postprocess_response(response)
169
+ except Exception as e:
170
+ logger.warn(f"Error in Anthropic completion: {e}")
171
+ raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "claude"))
172
+
173
+ def stream_completion(self,
174
+ messages: List[Dict[str, str]],
175
+ temperature: float = 0.0,
176
+ max_tokens: int = None,
177
+ stop: List[str] = None,
178
+ **kwargs) -> Generator[ModelResponse, None, None]:
179
+ """Synchronously call Anthropic to generate streaming response.
180
+
181
+ Args:
182
+ messages: Message list.
183
+ temperature: Temperature parameter.
184
+ max_tokens: Maximum number of tokens to generate.
185
+ stop: List of stop sequences.
186
+ **kwargs: Other parameters.
187
+
188
+ Returns:
189
+ Generator yielding ModelResponse chunks.
190
+ """
191
+ if not self.provider:
192
+ raise RuntimeError(
193
+ "Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")
194
+
195
+ try:
196
+ processed_data = self.preprocess_messages(messages)
197
+ processed_messages = processed_data["messages"]
198
+ system_content = processed_data["system"]
199
+ anthropic_params = self.get_anthropic_params(processed_messages, system_content, temperature, max_tokens,
200
+ stop, **kwargs)
201
+ anthropic_params["stream"] = True
202
+ response_stream = self.provider.messages.create(**anthropic_params)
203
+
204
+ for chunk in response_stream:
205
+ if not chunk:
206
+ continue
207
+
208
+ yield self.postprocess_stream_response(chunk)
209
+
210
+ except Exception as e:
211
+ logger.warn(f"Error in Anthropic stream_completion: {e}")
212
+ raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "claude"))
213
+
214
+ async def astream_completion(self,
215
+ messages: List[Dict[str, str]],
216
+ temperature: float = 0.0,
217
+ max_tokens: int = None,
218
+ stop: List[str] = None,
219
+ **kwargs) -> AsyncGenerator[ModelResponse, None]:
220
+ """Asynchronously call Anthropic to generate streaming response.
221
+
222
+ Args:
223
+ messages: Message list.
224
+ temperature: Temperature parameter.
225
+ max_tokens: Maximum number of tokens to generate.
226
+ stop: List of stop sequences.
227
+ **kwargs: Other parameters.
228
+
229
+ Returns:
230
+ AsyncGenerator yielding ModelResponse chunks.
231
+ """
232
+ if not self.async_provider:
233
+ raise RuntimeError(
234
+ "Async provider not initialized. Make sure 'async_enabled' parameter is set to True in initialization.")
235
+
236
+ try:
237
+ processed_data = self.preprocess_messages(messages)
238
+ processed_messages = processed_data["messages"]
239
+ system_content = processed_data["system"]
240
+ anthropic_params = self.get_anthropic_params(processed_messages, system_content, temperature, max_tokens,
241
+ stop, **kwargs)
242
+ anthropic_params["stream"] = True
243
+ response_stream = await self.async_provider.messages.create(**anthropic_params)
244
+
245
+ async for chunk in response_stream:
246
+ if not chunk:
247
+ continue
248
+
249
+ yield self.postprocess_stream_response(chunk)
250
+
251
+ except Exception as e:
252
+ logger.warn(f"Error in Anthropic astream_completion: {e}")
253
+ raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "claude"))
254
+
255
+ async def acompletion(self,
256
+ messages: List[Dict[str, str]],
257
+ temperature: float = 0.0,
258
+ max_tokens: int = None,
259
+ stop: List[str] = None,
260
+ **kwargs) -> ModelResponse:
261
+ """Asynchronously call Anthropic to generate response.
262
+
263
+ Args:
264
+ messages: Message list.
265
+ temperature: Temperature parameter.
266
+ max_tokens: Maximum number of tokens to generate.
267
+ stop: List of stop sequences.
268
+ **kwargs: Other parameters.
269
+
270
+ Returns:
271
+ ModelResponse object.
272
+ """
273
+ if not self.async_provider:
274
+ raise RuntimeError(
275
+ "Async provider not initialized. Make sure 'async_enabled' parameter is set to True in initialization.")
276
+
277
+ try:
278
+ processed_data = self.preprocess_messages(messages)
279
+ processed_messages = processed_data["messages"]
280
+ system_content = processed_data["system"]
281
+ anthropic_params = self.get_anthropic_params(processed_messages, system_content, temperature, max_tokens,
282
+ stop, **kwargs)
283
+ response = await self.async_provider.messages.create(**anthropic_params)
284
+
285
+ return self.postprocess_response(response)
286
+ except Exception as e:
287
+ logger.warn(f"Error in Anthropic acompletion: {e}")
288
+ raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "claude"))
289
+
290
+ def get_anthropic_params(self,
291
+ messages: List[Dict[str, str]],
292
+ system: str = None,
293
+ temperature: float = 0.0,
294
+ max_tokens: int = None,
295
+ stop: List[str] = None,
296
+ **kwargs) -> Dict[str, Any]:
297
+ if "tools" in kwargs:
298
+ openai_tools = kwargs["tools"]
299
+ claude_tools = []
300
+
301
+ for tool in openai_tools:
302
+ if tool["type"] == "function":
303
+ claude_tool = {
304
+ "name": tool["name"],
305
+ "description": tool["description"],
306
+ "input_schema": {
307
+ "type": "object",
308
+ "properties": tool["parameters"]["properties"],
309
+ "required": tool["parameters"].get("required", [])
310
+ }
311
+ }
312
+ claude_tools.append(claude_tool)
313
+
314
+ kwargs["tools"] = claude_tools
315
+
316
+ anthropic_params = {
317
+ "model": kwargs.get("model_name", self.model_name or ""),
318
+ "messages": messages,
319
+ "system": system,
320
+ "temperature": temperature,
321
+ "max_tokens": max_tokens or 4096,
322
+ "stop_sequences": stop,
323
+ }
324
+
325
+ if "tools" in kwargs and kwargs["tools"]:
326
+ anthropic_params["tools"] = kwargs["tools"]
327
+ anthropic_params["tool_choice"] = kwargs.get("tool_choice", "auto")
328
+
329
+ for param in ["top_p", "top_k", "metadata", "stream"]:
330
+ if param in kwargs:
331
+ anthropic_params[param] = kwargs[param]
332
+
333
+ return anthropic_params
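
For reference, `get_anthropic_params` maps OpenAI-style tool definitions onto Anthropic's `input_schema` format. A minimal sketch of the expected shapes, using a hypothetical `get_weather` tool:

```python
# OpenAI-style tool definition (hypothetical example)
openai_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

# Equivalent Anthropic tool definition produced by get_anthropic_params
claude_tool = {
    "name": "get_weather",
    "description": "Get the current weather for a city.",
    "input_schema": {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
}
```
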
aworld/models/llm.py ADDED
@@ -0,0 +1,584 @@
1
+ from typing import (
2
+ List,
3
+ Dict,
4
+ Union,
5
+ Generator,
6
+ AsyncGenerator,
7
+ )
8
+ from aworld.config import ConfigDict
9
+ from aworld.config.conf import AgentConfig, ClientType
10
+ from aworld.logs.util import logger
11
+
12
+ from aworld.core.llm_provider_base import LLMProviderBase
13
+ from aworld.models.openai_provider import OpenAIProvider, AzureOpenAIProvider
14
+ from aworld.models.anthropic_provider import AnthropicProvider
15
+ from aworld.models.ant_provider import AntProvider
16
+ from aworld.models.model_response import ModelResponse
17
+
18
+ # Predefined model names for common providers
19
+ MODEL_NAMES = {
20
+ "anthropic": ["claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20240620", "claude-3-opus-20240229"],
21
+ "openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini", "gpt-4o-mini"],
22
+ "azure_openai": ["gpt-4", "gpt-4-turbo", "gpt-4o", "gpt-35-turbo"],
23
+ }
24
+
25
+ # Endpoint patterns for identifying providers
26
+ ENDPOINT_PATTERNS = {
27
+ "openai": ["api.openai.com"],
28
+ "anthropic": ["api.anthropic.com", "claude-api"],
29
+ "azure_openai": ["openai.azure.com"],
30
+ "ant": ["zdfmng.alipay.com"],
31
+ }
32
+
33
+ # Provider class mapping
34
+ PROVIDER_CLASSES = {
35
+ "openai": OpenAIProvider,
36
+ "anthropic": AnthropicProvider,
37
+ "azure_openai": AzureOpenAIProvider,
38
+ "ant": AntProvider,
39
+ }
40
+
41
+
42
+ class LLMModel:
43
+ """Unified large model interface, encapsulates different model implementations, provides a unified completion method.
44
+ """
45
+
46
+ def __init__(self, conf: Union[ConfigDict, AgentConfig] = None, custom_provider: LLMProviderBase = None, **kwargs):
47
+ """Initialize unified model interface.
48
+
49
+ Args:
50
+ conf: Agent configuration, if provided, create model based on configuration.
51
+ custom_provider: Custom LLMProviderBase instance, if provided, use it directly.
52
+ **kwargs: Other parameters, may include:
53
+ - base_url: Specify model endpoint.
54
+ - api_key: API key.
55
+ - model_name: Model name.
56
+ - temperature: Temperature parameter.
57
+ """
58
+
59
+ # If custom_provider instance is provided, use it directly
60
+ if custom_provider is not None:
61
+ if not isinstance(custom_provider, LLMProviderBase):
62
+ raise TypeError(
63
+ "custom_provider must be an instance of LLMProviderBase")
64
+ self.provider_name = "custom"
65
+ self.provider = custom_provider
66
+ return
67
+
68
+ # Get basic parameters
69
+ base_url = kwargs.get("base_url") or (
70
+ conf.llm_base_url if conf else None)
71
+ model_name = kwargs.get("model_name") or (
72
+ conf.llm_model_name if conf else None)
73
+ llm_provider = conf.llm_provider if conf_contains_key(
74
+ conf, "llm_provider") else None
75
+
76
+ # Get API key from configuration (if any)
77
+ if conf and conf.llm_api_key:
78
+ kwargs["api_key"] = conf.llm_api_key
79
+
80
+ # Identify provider
81
+ self.provider_name = self._identify_provider(
82
+ llm_provider, base_url, model_name)
83
+
84
+ # Fill basic parameters
85
+ kwargs['base_url'] = base_url
86
+ kwargs['model_name'] = model_name
87
+
88
+ # Fill parameters for llm provider
89
+ kwargs['sync_enabled'] = conf.llm_sync_enabled if conf_contains_key(
90
+ conf, "llm_sync_enabled") else True
91
+ kwargs['async_enabled'] = conf.llm_async_enabled if conf_contains_key(
92
+ conf, "llm_async_enabled") else True
93
+ kwargs['client_type'] = conf.llm_client_type if conf_contains_key(
94
+ conf, "llm_client_type") else ClientType.SDK
95
+
96
+ kwargs.update(self._transfer_conf_to_args(conf))
97
+
98
+ # Create model provider based on provider_name
99
+ self._create_provider(**kwargs)
100
+
101
+ def _transfer_conf_to_args(self, conf: Union[ConfigDict, AgentConfig] = None) -> dict:
102
+ """
103
+ Transfer parameters from conf to args
104
+
105
+ Args:
106
+ conf: config object
107
+ """
108
+ if not conf:
109
+ return {}
110
+
111
+ # Get all parameters from conf
112
+ if type(conf).__name__ == 'AgentConfig':
113
+ conf_dict = conf.model_dump()
114
+ else: # ConfigDict
115
+ conf_dict = conf
116
+
117
+ ignored_keys = ["llm_provider", "llm_base_url", "llm_model_name", "llm_api_key", "llm_sync_enabled",
118
+ "llm_async_enabled", "llm_client_type"]
119
+ args = {}
120
+ # Filter out used parameters and add remaining parameters to args
121
+ for key, value in conf_dict.items():
122
+ if key not in ignored_keys and value is not None:
123
+ args[key] = value
124
+
125
+ return args
126
+
127
+ def _identify_provider(self, provider: str = None, base_url: str = None, model_name: str = None) -> str:
128
+ """Identify LLM provider.
129
+
130
+ Identification logic:
131
+ 1. If provider is specified and doesn't need to be overridden, use the specified provider.
132
+ 2. If base_url is provided, try to identify provider based on base_url.
133
+ 3. If model_name is provided, try to identify provider based on model_name.
134
+ 4. If none can be identified, default to "openai".
135
+
136
+ Args:
137
+ provider: Specified provider.
138
+ base_url: Service URL.
139
+ model_name: Model name.
140
+
141
+ Returns:
142
+ str: Identified provider.
143
+ """
144
+ # Default provider
145
+ identified_provider = "openai"
146
+
147
+ # Identify provider based on base_url
148
+ if base_url:
149
+ for p, patterns in ENDPOINT_PATTERNS.items():
150
+ if any(pattern in base_url for pattern in patterns):
151
+ identified_provider = p
152
+ logger.info(
153
+ f"Identified provider: {identified_provider} based on base_url: {base_url}")
154
+ return identified_provider
155
+
156
+ # Identify provider based on model_name
157
+ if model_name and not base_url:
158
+ for p, models in MODEL_NAMES.items():
159
+ if model_name in models or any(model_name.startswith(model) for model in models):
160
+ identified_provider = p
161
+ logger.info(
162
+ f"Identified provider: {identified_provider} based on model_name: {model_name}")
163
+ break
164
+
165
+ if provider and provider in PROVIDER_CLASSES and identified_provider and identified_provider != provider:
166
+ logger.warning(
167
+ f"Provider mismatch: {provider} != {identified_provider}, using {provider} as provider")
168
+ identified_provider = provider
169
+
170
+ return identified_provider
171
+
172
+ def _create_provider(self, **kwargs):
173
+ """Return the corresponding provider instance based on provider.
174
+
175
+ Args:
176
+ **kwargs: Parameters, may include:
177
+ - base_url: Model endpoint.
178
+ - api_key: API key.
179
+ - model_name: Model name.
180
+ - temperature: Temperature parameter.
181
+ - timeout: Timeout.
182
+ - max_retries: Maximum number of retries.
183
+ """
184
+ self.provider = PROVIDER_CLASSES[self.provider_name](**kwargs)
185
+
186
+ @classmethod
187
+ def supported_providers(cls) -> list[str]:
188
+ return list(PROVIDER_CLASSES.keys())
189
+
190
+ def supported_models(self) -> list[str]:
191
+ """Get supported models for the current provider.
192
+ Returns:
193
+ list: Supported models.
194
+ """
195
+ return self.provider.supported_models() if self.provider else []
196
+
197
+ async def acompletion(self,
198
+ messages: List[Dict[str, str]],
199
+ temperature: float = 0.0,
200
+ max_tokens: int = None,
201
+ stop: List[str] = None,
202
+ **kwargs) -> ModelResponse:
203
+ """Asynchronously call model to generate response.
204
+
205
+ Args:
206
+ messages: Message list, format is [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}].
207
+ temperature: Temperature parameter.
208
+ max_tokens: Maximum number of tokens to generate.
209
+ stop: List of stop sequences.
210
+ **kwargs: Other parameters.
211
+
212
+ Returns:
213
+ ModelResponse: Unified model response object.
214
+ """
215
+ # Call provider's acompletion method directly
216
+ return await self.provider.acompletion(
217
+ messages=messages,
218
+ temperature=temperature,
219
+ max_tokens=max_tokens,
220
+ stop=stop,
221
+ **kwargs
222
+ )
223
+
224
+ def completion(self,
225
+ messages: List[Dict[str, str]],
226
+ temperature: float = 0.0,
227
+ max_tokens: int = None,
228
+ stop: List[str] = None,
229
+ **kwargs) -> ModelResponse:
230
+ """Synchronously call model to generate response.
231
+
232
+ Args:
233
+ messages: Message list, format is [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}].
234
+ temperature: Temperature parameter.
235
+ max_tokens: Maximum number of tokens to generate.
236
+ stop: List of stop sequences.
237
+ **kwargs: Other parameters.
238
+
239
+ Returns:
240
+ ModelResponse: Unified model response object.
241
+ """
242
+ # Call provider's completion method directly
243
+ return self.provider.completion(
244
+ messages=messages,
245
+ temperature=temperature,
246
+ max_tokens=max_tokens,
247
+ stop=stop,
248
+ **kwargs
249
+ )
250
+
251
+ def stream_completion(self,
252
+ messages: List[Dict[str, str]],
253
+ temperature: float = 0.0,
254
+ max_tokens: int = None,
255
+ stop: List[str] = None,
256
+ **kwargs) -> Generator[ModelResponse, None, None]:
257
+ """Synchronously call model to generate streaming response.
258
+
259
+ Args:
260
+ messages: Message list, format is [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}].
261
+ temperature: Temperature parameter.
262
+ max_tokens: Maximum number of tokens to generate.
263
+ stop: List of stop sequences.
264
+ **kwargs: Other parameters.
265
+
266
+ Returns:
267
+ Generator yielding ModelResponse chunks.
268
+ """
269
+ # Call provider's stream_completion method directly
270
+ return self.provider.stream_completion(
271
+ messages=messages,
272
+ temperature=temperature,
273
+ max_tokens=max_tokens,
274
+ stop=stop,
275
+ **kwargs
276
+ )
277
+
278
+ async def astream_completion(self,
279
+ messages: List[Dict[str, str]],
280
+ temperature: float = 0.0,
281
+ max_tokens: int = None,
282
+ stop: List[str] = None,
283
+ **kwargs) -> AsyncGenerator[ModelResponse, None]:
284
+ """Asynchronously call model to generate streaming response.
285
+
286
+ Args:
287
+ messages: Message list, format is [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}].
288
+ temperature: Temperature parameter.
289
+ max_tokens: Maximum number of tokens to generate.
290
+ stop: List of stop sequences.
291
+ **kwargs: Other parameters, may include:
292
+ - base_url: Specify model endpoint.
293
+ - api_key: API key.
294
+ - model_name: Model name.
295
+
296
+ Returns:
297
+ AsyncGenerator yielding ModelResponse chunks.
298
+ """
299
+ # Call provider's astream_completion method directly
300
+ async for chunk in self.provider.astream_completion(
301
+ messages=messages,
302
+ temperature=temperature,
303
+ max_tokens=max_tokens,
304
+ stop=stop,
305
+ **kwargs
306
+ ):
307
+ yield chunk
308
+
309
+ def speech_to_text(self,
310
+ audio_file: str,
311
+ language: str = None,
312
+ prompt: str = None,
313
+ **kwargs) -> ModelResponse:
314
+ """Convert speech to text.
315
+
316
+ Args:
317
+ audio_file: Path to audio file or file object.
318
+ language: Audio language, optional.
319
+ prompt: Transcription prompt, optional.
320
+ **kwargs: Other parameters.
321
+
322
+ Returns:
323
+ ModelResponse: Unified model response object, with content field containing the transcription result.
324
+
325
+ Raises:
326
+ LLMResponseError: When LLM response error occurs.
327
+ NotImplementedError: When provider does not support speech to text conversion.
328
+ """
329
+ return self.provider.speech_to_text(
330
+ audio_file=audio_file,
331
+ language=language,
332
+ prompt=prompt,
333
+ **kwargs
334
+ )
335
+
336
+ async def aspeech_to_text(self,
337
+ audio_file: str,
338
+ language: str = None,
339
+ prompt: str = None,
340
+ **kwargs) -> ModelResponse:
341
+ """Asynchronously convert speech to text.
342
+
343
+ Args:
344
+ audio_file: Path to audio file or file object.
345
+ language: Audio language, optional.
346
+ prompt: Transcription prompt, optional.
347
+ **kwargs: Other parameters.
348
+
349
+ Returns:
350
+ ModelResponse: Unified model response object, with content field containing the transcription result.
351
+
352
+ Raises:
353
+ LLMResponseError: When LLM response error occurs.
354
+ NotImplementedError: When provider does not support speech to text conversion.
355
+ """
356
+ return await self.provider.aspeech_to_text(
357
+ audio_file=audio_file,
358
+ language=language,
359
+ prompt=prompt,
360
+ **kwargs
361
+ )
362
+
363
+
364
+ def register_llm_provider(provider: str, provider_class: type):
365
+ """Register a custom LLM provider.
366
+
367
+ Args:
368
+ provider: Provider name.
369
+ provider_class: Provider class, must inherit from LLMProviderBase.
370
+ """
371
+ if not issubclass(provider_class, LLMProviderBase):
372
+ raise TypeError("provider_class must be a subclass of LLMProviderBase")
373
+ PROVIDER_CLASSES[provider] = provider_class
374
+
375
+
376
+ def conf_contains_key(conf: Union[ConfigDict, AgentConfig], key: str) -> bool:
377
+ """Check if conf contains key.
378
+ Args:
379
+ conf: Config object.
380
+ key: Key to check.
381
+ Returns:
382
+ bool: Whether conf contains key.
383
+ """
384
+ if not conf:
385
+ return False
386
+ if type(conf).__name__ == 'AgentConfig':
387
+ return hasattr(conf, key)
388
+ else:
389
+ return key in conf
390
+
391
+
392
+ def get_llm_model(conf: Union[ConfigDict, AgentConfig] = None,
393
+ custom_provider: LLMProviderBase = None,
394
+ **kwargs) -> Union[LLMModel, 'ChatOpenAI']:
395
+ """Get a unified LLM model instance.
396
+
397
+ Args:
398
+ conf: Agent configuration, if provided, create model based on configuration.
399
+ custom_provider: Custom LLMProviderBase instance, if provided, use it directly.
400
+ **kwargs: Other parameters, may include:
401
+ - base_url: Specify model endpoint.
402
+ - api_key: API key.
403
+ - model_name: Model name.
404
+ - temperature: Temperature parameter.
405
+
406
+ Returns:
407
+ Unified model interface.
408
+ """
409
+ # Create and return LLMModel instance directly
410
+ llm_provider = conf.llm_provider if conf_contains_key(
411
+ conf, "llm_provider") else None
412
+
413
+ if (llm_provider == "chatopenai"):
414
+ from langchain_openai import ChatOpenAI
415
+
416
+ base_url = kwargs.get("base_url") or (
417
+ conf.llm_base_url if conf_contains_key(conf, "llm_base_url") else None)
418
+ model_name = kwargs.get("model_name") or (
419
+ conf.llm_model_name if conf_contains_key(conf, "llm_model_name") else None)
420
+ api_key = kwargs.get("api_key") or (
421
+ conf.llm_api_key if conf_contains_key(conf, "llm_api_key") else None)
422
+
423
+ return ChatOpenAI(
424
+ model=model_name,
425
+ temperature=kwargs.get("temperature", conf.llm_temperature if conf_contains_key(
426
+ conf, "llm_temperature") else 0.0),
427
+ base_url=base_url,
428
+ api_key=api_key,
429
+ )
430
+
431
+ return LLMModel(conf=conf, custom_provider=custom_provider, **kwargs)
432
+
433
+
434
+ def call_llm_model(
435
+ llm_model: LLMModel,
436
+ messages: List[Dict[str, str]],
437
+ temperature: float = 0.0,
438
+ max_tokens: int = None,
439
+ stop: List[str] = None,
440
+ stream: bool = False,
441
+ **kwargs
442
+ ) -> Union[ModelResponse, Generator[ModelResponse, None, None]]:
443
+ """Convenience function to call LLM model.
444
+
445
+ Args:
446
+ llm_model: LLM model instance.
447
+ messages: Message list.
448
+ temperature: Temperature parameter.
449
+ max_tokens: Maximum number of tokens to generate.
450
+ stop: List of stop sequences.
451
+ stream: Whether to return a streaming response.
452
+ **kwargs: Other parameters.
453
+
454
+ Returns:
455
+ Model response or response generator.
456
+ """
457
+ if stream:
458
+ return llm_model.stream_completion(
459
+ messages=messages,
460
+ temperature=temperature,
461
+ max_tokens=max_tokens,
462
+ stop=stop,
463
+ **kwargs
464
+ )
465
+ else:
466
+ return llm_model.completion(
467
+ messages=messages,
468
+ temperature=temperature,
469
+ max_tokens=max_tokens,
470
+ stop=stop,
471
+ **kwargs
472
+ )
473
+
474
+
475
+ async def acall_llm_model(
476
+ llm_model: LLMModel,
477
+ messages: List[Dict[str, str]],
478
+ temperature: float = 0.0,
479
+ max_tokens: int = None,
480
+ stop: List[str] = None,
481
+ stream: bool = False,
482
+ **kwargs
483
+ ) -> ModelResponse:
484
+ """Convenience function to asynchronously call LLM model.
485
+
486
+ Args:
487
+ llm_model: LLM model instance.
488
+ messages: Message list.
489
+ temperature: Temperature parameter.
490
+ max_tokens: Maximum number of tokens to generate.
491
+ stop: List of stop sequences.
492
+ stream: Ignored here; use acall_llm_model_stream for streaming responses.
493
+ **kwargs: Other parameters.
494
+
495
+ Returns:
496
+ ModelResponse object.
497
+ """
498
+ return await llm_model.acompletion(
499
+ messages=messages,
500
+ temperature=temperature,
501
+ max_tokens=max_tokens,
502
+ stop=stop,
503
+ **kwargs
504
+ )
505
+
506
+
507
+ async def acall_llm_model_stream(
508
+ llm_model: LLMModel,
509
+ messages: List[Dict[str, str]],
510
+ temperature: float = 0.0,
511
+ max_tokens: int = None,
512
+ stop: List[str] = None,
513
+ **kwargs
514
+ ) -> AsyncGenerator[ModelResponse, None]:
515
+ async for chunk in llm_model.astream_completion(
516
+ messages=messages,
517
+ temperature=temperature,
518
+ max_tokens=max_tokens,
519
+ stop=stop,
520
+ **kwargs
521
+ ):
522
+ yield chunk
523
+
524
+
525
+ def speech_to_text(
526
+ llm_model: LLMModel,
527
+ audio_file: str,
528
+ language: str = None,
529
+ prompt: str = None,
530
+ **kwargs
531
+ ) -> ModelResponse:
532
+ """Convenience function to convert speech to text.
533
+
534
+ Args:
535
+ llm_model: LLM model instance.
536
+ audio_file: Path to audio file or file object.
537
+ language: Audio language, optional.
538
+ prompt: Transcription prompt, optional.
539
+ **kwargs: Other parameters.
540
+
541
+ Returns:
542
+ ModelResponse: Unified model response object, with content field containing the transcription result.
543
+ """
544
+ if llm_model.provider_name != "openai":
545
+ raise NotImplementedError(
546
+ f"Speech-to-text functionality is currently only supported for OpenAI compatible provider, current provider: {llm_model.provider_name}")
547
+
548
+ return llm_model.speech_to_text(
549
+ audio_file=audio_file,
550
+ language=language,
551
+ prompt=prompt,
552
+ **kwargs
553
+ )
554
+
555
+
556
+ async def aspeech_to_text(
557
+ llm_model: LLMModel,
558
+ audio_file: str,
559
+ language: str = None,
560
+ prompt: str = None,
561
+ **kwargs
562
+ ) -> ModelResponse:
563
+ """Convenience function to asynchronously convert speech to text.
564
+
565
+ Args:
566
+ llm_model: LLM model instance.
567
+ audio_file: Path to audio file or file object.
568
+ language: Audio language, optional.
569
+ prompt: Transcription prompt, optional.
570
+ **kwargs: Other parameters.
571
+
572
+ Returns:
573
+ ModelResponse: Unified model response object, with content field containing the transcription result.
574
+ """
575
+ if llm_model.provider_name != "openai":
576
+ raise NotImplementedError(
577
+ f"Speech-to-text functionality is currently only supported for OpenAI compatible provider, current provider: {llm_model.provider_name}")
578
+
579
+ return await llm_model.aspeech_to_text(
580
+ audio_file=audio_file,
581
+ language=language,
582
+ prompt=prompt,
583
+ **kwargs
584
+ )
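
`register_llm_provider` extends the `PROVIDER_CLASSES` table at runtime. A minimal sketch of registering a custom provider, assuming only for illustration that `LLMProviderBase` can be subclassed with a stub `_init_provider`:

```python
from aworld.core.llm_provider_base import LLMProviderBase
from aworld.models.llm import LLMModel, register_llm_provider


class EchoProvider(LLMProviderBase):
    """Hypothetical provider used only to illustrate registration."""

    def _init_provider(self):
        # Return whatever client object the provider wraps; None for this stub.
        return None


# After registration, "echo" is a valid provider name for LLMModel/get_llm_model.
register_llm_provider("echo", EchoProvider)
print(LLMModel.supported_providers())  # [..., "echo"]
```
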
aworld/models/llm_http_handler.py ADDED
@@ -0,0 +1,397 @@
1
+ """HTTP handler for LLM providers.
2
+
3
+ This module provides a generic HTTP handler for making requests to LLM providers
4
+ when direct SDK usage is not desired.
5
+ """
6
+
7
+ import json
8
+ import asyncio
9
+ import random
10
+ import time
11
+ from typing import Any, Dict, List, Optional, Union, Generator, AsyncGenerator
12
+ import requests
13
+ from requests import HTTPError
14
+
15
+ from aworld.logs.util import logger
16
+ from aworld.utils import import_package
17
+
18
+ class LLMHTTPHandler:
19
+ """HTTP handler for LLM providers.
20
+
21
+ This class provides methods to make HTTP requests to LLM providers
22
+ instead of using their SDKs directly.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ base_url: str,
28
+ api_key: str,
29
+ model_name: str,
30
+ headers: Optional[Dict[str, str]] = None,
31
+ timeout: int = 180,
32
+ max_retries: int = 3,
33
+ ) -> None:
34
+ """Initialize the HTTP handler.
35
+
36
+ Args:
37
+ base_url: Base URL for the LLM API.
38
+ api_key: API key for authentication.
39
+ model_name: Name of the model to use.
40
+ headers: Additional headers to include in requests.
41
+ timeout: Request timeout in seconds.
42
+ max_retries: Maximum number of retries for failed requests.
43
+ """
44
+ import_package("aiohttp")
45
+ self.base_url = base_url.rstrip("/")
46
+ self.api_key = api_key
47
+ self.model_name = model_name
48
+ self.timeout = timeout
49
+ self.max_retries = max_retries
50
+
51
+ # Set up default headers
52
+ self.headers = {
53
+ "Content-Type": "application/json",
54
+ "Authorization": f"Bearer {api_key}",
55
+ }
56
+ if headers:
57
+ self.headers.update(headers)
58
+
59
+ def _parse_sse_line(self, line: bytes) -> Optional[Dict[str, Any]]:
60
+ """Parse a Server-Sent Events (SSE) line.
61
+
62
+ Args:
63
+ line: Raw SSE line.
64
+
65
+ Returns:
66
+ Parsed JSON data if successful, None otherwise.
67
+ """
68
+ try:
69
+ # Remove 'data: ' prefix if present
70
+ line_str = line.decode('utf-8').strip()
71
+ if line_str.startswith('data: '):
72
+ line_str = line_str[6:]
73
+
74
+ # Skip empty lines
75
+ if not line_str:
76
+ return None
77
+
78
+ return json.loads(line_str)
79
+ except (json.JSONDecodeError, UnicodeDecodeError) as e:
80
+ logger.warning(f"Failed to parse SSE line: {line}, error: {str(e)}")
81
+ return None
82
+
83
+ def _make_request(
84
+ self,
85
+ endpoint: str,
86
+ data: Dict[str, Any],
87
+ stream: bool = False,
88
+ headers: Optional[Dict[str, str]] = None,
89
+ ) -> Union[Dict[str, Any], Generator[Dict[str, Any], None, None]]:
90
+ """Make a synchronous HTTP request.
91
+
92
+ Args:
93
+ endpoint: API endpoint to call.
94
+ data: Request data to send.
95
+ stream: Whether to stream the response.
96
+
97
+ Returns:
98
+ Response data or generator of response chunks.
99
+
100
+ Raises:
101
+ requests.exceptions.RequestException: If the request fails.
102
+ """
103
+ url = f"{self.base_url}/{endpoint.lstrip('/')}"
104
+ request_headers = self.headers.copy()
105
+ if headers:
106
+ request_headers.update(headers)
107
+
108
+
109
+ try:
110
+ if stream:
111
+ response = requests.post(
112
+ url,
113
+ headers=request_headers,
114
+ json=data,
115
+ stream=True,
116
+ timeout=self.timeout,
117
+ )
118
+ response.raise_for_status()
119
+
120
+ def generate_chunks():
121
+ for line in response.iter_lines():
122
+ if line:
123
+ line_str = line.decode('utf-8').strip()
124
+ if line_str.startswith('data: '):
125
+ line_content = line_str[6:]
126
+
127
+ if line_content == "[DONE]":
128
+ yield {"status": "done", "message": "Stream completed"}
129
+ break
130
+ elif line_content == "[REVOKE]":
131
+ yield {"status": "revoke", "message": "Content should be revoked"}
132
+ continue
133
+ elif line_content == "[FAIL]":
134
+ yield {"status": "fail", "message": "Request failed"}
135
+ break
136
+ elif line_content.startswith("[FAIL]_stream was reset: CANCEL"):
137
+ yield {"status": "cancel", "message": "Stream was cancelled"}
138
+ break
139
+
140
+ chunk = self._parse_sse_line(line)
141
+ if chunk is not None:
142
+ yield chunk
143
+ return generate_chunks()
144
+ else:
145
+ response = requests.post(
146
+ url,
147
+ headers=request_headers,
148
+ json=data,
149
+ timeout=self.timeout,
150
+ )
151
+ response.raise_for_status()
152
+ return response.json()
153
+ except Exception as e:
154
+ logger.error(f"Error in HttpHandler: {str(e)}")
155
+ raise
156
+
157
+ async def _make_async_request_stream(
158
+ self,
159
+ endpoint: str,
160
+ data: Dict[str, Any],
161
+ headers: Optional[Dict[str, str]] = None,
162
+ ) -> AsyncGenerator[Dict[str, Any], None]:
163
+ """Make an asynchronous streaming HTTP request.
164
+
165
+ Args:
166
+ endpoint: API endpoint to call.
167
+ data: Request data to send.
168
+
169
+ Yields:
170
+ Response chunks.
171
+
172
+ Raises:
173
+ aiohttp.ClientError: If the request fails.
174
+ """
175
+ import aiohttp
176
+ url = f"{self.base_url}/{endpoint.lstrip('/')}"
177
+ request_headers = self.headers.copy()
178
+ if headers:
179
+ request_headers.update(headers)
180
+
181
+ # Create an independent session and keep it open
182
+ session = aiohttp.ClientSession()
183
+ try:
184
+ response = await session.post(
185
+ url,
186
+ headers=request_headers,
187
+ json=data,
188
+ timeout=self.timeout,
189
+ )
190
+ response.raise_for_status()
191
+
192
+ # Implement async generator directly
193
+ async for line in response.content:
194
+ if line:
195
+ line_str = line.decode('utf-8').strip()
196
+ if line_str.startswith('data: '):
197
+ line_content = line_str[6:]
198
+
199
+ if line_content == "[DONE]":
200
+ yield {"status": "done", "message": "Stream completed"}
201
+ break
202
+ elif line_content == "[REVOKE]":
203
+ yield {"status": "revoke", "message": "Content should be revoked"}
204
+ continue
205
+ elif line_content == "[FAIL]":
206
+ yield {"status": "fail", "message": "Request failed"}
207
+ break
208
+ elif line_content.startswith("[FAIL]_stream was reset: CANCEL"):
209
+ yield {"status": "cancel", "message": "Stream was cancelled"}
210
+ break
211
+
212
+ chunk = self._parse_sse_line(line)
213
+ if chunk is not None:
214
+ yield chunk
215
+ except Exception as e:
216
+ logger.error(f"Error in stream: {str(e)}")
217
+ raise
218
+ finally:
219
+ # Ensure the session is eventually closed
220
+ await session.close()
221
+
222
+ async def _make_async_request(
223
+ self,
224
+ endpoint: str,
225
+ data: Dict[str, Any],
226
+ headers: Optional[Dict[str, str]] = None,
227
+ ) -> Dict[str, Any]:
228
+ """Make an asynchronous non-streaming HTTP request.
229
+
230
+ Args:
231
+ endpoint: API endpoint to call.
232
+ data: Request data to send.
233
+
234
+ Returns:
235
+ Response data.
236
+
237
+ Raises:
238
+ aiohttp.ClientError: If the request fails.
239
+ """
240
+ import aiohttp
241
+ url = f"{self.base_url}/{endpoint.lstrip('/')}"
242
+ request_headers = self.headers.copy()
243
+ if headers:
244
+ request_headers.update(headers)
245
+
246
+ async with aiohttp.ClientSession() as session:
247
+ async with session.post(
248
+ url,
249
+ headers=request_headers,
250
+ json=data,
251
+ timeout=self.timeout,
252
+ ) as response:
253
+ response.raise_for_status()
254
+ return await response.json()
255
+
256
+ def sync_call(
257
+ self,
258
+ data: Dict[str, Any],
259
+ endpoint: str = None,
260
+ headers: Optional[Dict[str, str]] = None,
261
+ ) -> Dict[str, Any]:
262
+ """Make a synchronous completion request.
263
+
264
+ Args:
265
+ data: Request data.
266
+
267
+ Returns:
268
+ Response data.
269
+ """
270
+ logger.debug(f"sync_call request data: {data}")
271
+
272
+ if not endpoint:
273
+ endpoint = "chat/completions"
274
+
275
+ retries = 0
276
+ while retries < self.max_retries:
277
+ try:
278
+ response = self._make_request(endpoint, data, headers=headers)
279
+ return response
280
+ except Exception as e:
281
+ last_error = e
282
+ retries += 1
283
+ if retries < self.max_retries:
284
+ logger.warning(f"Request failed, retrying ({retries}/{self.max_retries}): {str(e)}")
285
+ # Exponential backoff with jitter
286
+ backoff = min(2 ** retries + random.uniform(0, 1), 10)
287
+ time.sleep(backoff)
288
+ else:
289
+ logger.error(f"Request failed after {self.max_retries} retries: {str(e)}")
290
+ raise last_error
291
+
292
+ async def async_call(
293
+ self,
294
+ data: Dict[str, Any],
295
+ endpoint: str = None,
296
+ headers: Optional[Dict[str, str]] = None,
297
+ ) -> Dict[str, Any]:
298
+ """Make an asynchronous completion request.
299
+
300
+ Args:
301
+ data: Request data.
302
+
303
+ Returns:
304
+ Response data.
305
+ """
306
+ import aiohttp
307
+ logger.info(f"async_call request data: {data}")
308
+
309
+ retries = 0
310
+ last_error = None
311
+ if not endpoint:
312
+ endpoint = "chat/completions"
313
+
314
+ while retries < self.max_retries:
315
+ try:
316
+ response = await self._make_async_request(endpoint, data, headers=headers)
317
+ return response
318
+ except (aiohttp.ClientError, asyncio.TimeoutError) as e:
319
+ last_error = e
320
+ retries += 1
321
+ if retries < self.max_retries:
322
+ logger.warning(f"Request failed, retrying ({retries}/{self.max_retries}): {str(e)}")
323
+ # Exponential backoff with jitter
324
+ backoff = min(2 ** retries + random.uniform(0, 1), 10)
325
+ await asyncio.sleep(backoff)
326
+ else:
327
+ logger.error(f"Request failed after {self.max_retries} retries: {str(e)}")
328
+ raise last_error
329
+
330
+ def sync_stream_call(
331
+ self,
332
+ data: Dict[str, Any],
333
+ endpoint: str = None,
334
+ headers: Optional[Dict[str, str]] = None,
335
+ ) -> Generator[Dict[str, Any], None, None]:
336
+ """Make a synchronous streaming completion request.
337
+
338
+ Args:
339
+ data: Request data.
340
+
341
+ Yields:
342
+ Response chunks.
343
+ """
344
+ data["stream"] = True
345
+ logger.info(f"sync_stream_call request data: {data}")
346
+ retries = 0
347
+
348
+ while retries < self.max_retries:
349
+ try:
350
+ for chunk in self._make_request(endpoint or "chat/completions", data, stream=True, headers=headers):
351
+ yield chunk
352
+ return # Exit after completing stream processing
353
+ except Exception as e:
354
+ last_error = e
355
+ retries += 1
356
+ if retries < self.max_retries:
357
+ logger.warning(f"Stream connection failed, retrying ({retries}/{self.max_retries}): {str(e)}")
358
+ else:
359
+ logger.error(f"Stream connection failed after {self.max_retries} retries: {str(e)}")
360
+ raise last_error
361
+
362
+
363
+ async def async_stream_call(
364
+ self,
365
+ data: Dict[str, Any],
366
+ endpoint: str = None,
367
+ headers: Optional[Dict[str, str]] = None,
368
+ ) -> AsyncGenerator[Dict[str, Any], None]:
369
+ """Make an asynchronous streaming completion request.
370
+
371
+ Args:
372
+ data: Request data.
373
+
374
+ Yields:
375
+ Response chunks.
376
+ """
377
+ import aiohttp
378
+ data["stream"] = True
379
+ logger.info(f"async_stream_call request data: {data}")
380
+
381
+ retries = 0
382
+ last_error = None
383
+
384
+ while retries < self.max_retries:
385
+ try:
386
+ async for chunk in self._make_async_request_stream(endpoint or "chat/completions", data, headers=headers):
387
+ yield chunk
388
+ return # Exit after completing stream processing
389
+ except (aiohttp.ClientError, asyncio.TimeoutError) as e:
390
+ last_error = e
391
+ retries += 1
392
+ if retries < self.max_retries:
393
+ logger.warning(f"Stream connection failed, retrying ({retries}/{self.max_retries}): {str(e)}")
394
+ await asyncio.sleep(1) # Wait one second before retrying
395
+ else:
396
+ logger.error(f"Stream connection failed after {self.max_retries} retries: {str(e)}")
397
+ raise last_error
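
A minimal usage sketch for `LLMHTTPHandler` against an OpenAI-compatible endpoint. The URL, key, and model below are placeholders; in normal use the handler is created by a provider when `client_type` is `ClientType.HTTP` rather than constructed by hand:

```python
from aworld.models.llm_http_handler import LLMHTTPHandler

handler = LLMHTTPHandler(
    base_url="https://api.openai.com/v1",  # placeholder endpoint
    api_key="your_api_key",                # placeholder key
    model_name="gpt-4o",
)

payload = {
    "model": "gpt-4o",
    "messages": [{"role": "user", "content": "Say hello."}],
}

# Non-streaming call; defaults to the chat/completions endpoint.
result = handler.sync_call(payload)
print(result["choices"][0]["message"]["content"])

# Streaming call; yields parsed SSE chunks plus status markers such as {"status": "done"}.
for chunk in handler.sync_stream_call(payload):
    if chunk.get("status") == "done":
        break
    delta = chunk.get("choices", [{}])[0].get("delta", {})
    print(delta.get("content", ""), end="", flush=True)
```
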
aworld/models/model_response.py ADDED
@@ -0,0 +1,631 @@
1
+ from typing import Any, Dict, List, Optional
2
+ import json
3
+ from pydantic import BaseModel
4
+
5
+
6
+ class LLMResponseError(Exception):
7
+ """Represents an error in LLM response.
8
+
9
+ Attributes:
10
+ message: Error message
11
+ model: Model name
12
+ response: Original response object
13
+ """
14
+
15
+ def __init__(self, message: str, model: str = "unknown", response: Any = None):
16
+ """
17
+ Initialize LLM response error
18
+
19
+ Args:
20
+ message: Error message
21
+ model: Model name
22
+ response: Original response object
23
+ """
24
+ self.message = message
25
+ self.model = model
26
+ self.response = response
27
+ super().__init__(f"LLM Error ({model}): {message}. Response: {response}")
28
+
29
+
30
+ class Function(BaseModel):
31
+ """
32
+ Represents a function call made by a model
33
+ """
34
+ name: str
35
+ arguments: str = None
36
+
37
+
38
+ class ToolCall(BaseModel):
39
+ """
40
+ Represents a tool call made by a model
41
+ """
42
+
43
+ id: str
44
+ type: str = "function"
45
+ function: Function = None
46
+
47
+ # name: str = None
48
+ # arguments: str = None
49
+
50
+ @classmethod
51
+ def from_dict(cls, data: Dict[str, Any]) -> 'ToolCall':
52
+ """
53
+ Create ToolCall from dictionary representation
54
+
55
+ Args:
56
+ data: Dictionary containing tool call data
57
+
58
+ Returns:
59
+ ToolCall object
60
+ """
61
+ if not data:
62
+ return None
63
+
64
+ tool_id = data.get('id', f"call_{hash(str(data)) & 0xffffffff:08x}")
65
+ tool_type = data.get('type', 'function')
66
+
67
+ function_data = data.get('function', {})
68
+ name = function_data.get('name')
69
+
70
+ arguments = function_data.get('arguments')
71
+ # Ensure arguments is a string
72
+ if arguments is not None and not isinstance(arguments, str):
73
+ arguments = json.dumps(arguments, ensure_ascii=False)
74
+
75
+ function = Function(name=name, arguments=arguments)
76
+
77
+ return cls(
78
+ id=tool_id,
79
+ type=tool_type,
80
+ function=function,
81
+ # name=name,
82
+ # arguments=arguments,
83
+ )
84
+
85
+ def to_dict(self) -> Dict[str, Any]:
86
+ """
87
+ Convert ToolCall to dictionary representation
88
+
89
+ Returns:
90
+ Dictionary representation
91
+ """
92
+ return {
93
+ "id": self.id,
94
+ "type": self.type,
95
+ "function": {
96
+ "name": self.function.name,
97
+ "arguments": self.function.arguments
98
+ }
99
+ }
100
+
101
+ def __repr__(self):
102
+ return json.dumps(self.to_dict(), ensure_ascii=False)
103
+
104
+ def __iter__(self):
105
+ """
106
+ Make ToolCall dict-like for JSON serialization
107
+ """
108
+ yield from self.to_dict().items()
109
+
110
+
111
+ class ModelResponse:
112
+ """
113
+ Unified model response class for encapsulating responses from different LLM providers
114
+ """
115
+
116
+ def __init__(
117
+ self,
118
+ id: str,
119
+ model: str,
120
+ content: str = None,
121
+ tool_calls: List[ToolCall] = None,
122
+ usage: Dict[str, int] = None,
123
+ error: str = None,
124
+ raw_response: Any = None,
125
+ message: Dict[str, Any] = None
126
+ ):
127
+ """
128
+ Initialize ModelResponse object
129
+
130
+ Args:
131
+ id: Response ID
132
+ model: Model name used
133
+ content: Generated text content
134
+ tool_calls: List of tool calls
135
+ usage: Usage statistics (token counts, etc.)
136
+ error: Error message (if any)
137
+ raw_response: Original response object
138
+ message: Complete message object, can be used for subsequent API calls
139
+ """
140
+ self.id = id
141
+ self.model = model
142
+ self.content = content
143
+ self.tool_calls = tool_calls
144
+ self.usage = usage or {
145
+ "completion_tokens": 0,
146
+ "prompt_tokens": 0,
147
+ "total_tokens": 0
148
+ }
149
+ self.error = error
150
+ self.raw_response = raw_response
151
+
152
+ # If message is not provided, construct one from other fields
153
+ if message is None:
154
+ self.message = {
155
+ "role": "assistant",
156
+ "content": content
157
+ }
158
+
159
+ if tool_calls:
160
+ self.message["tool_calls"] = [tool_call.to_dict() for tool_call in tool_calls]
161
+ else:
162
+ self.message = message
163
+
164
+ @classmethod
165
+ def from_openai_response(cls, response: Any) -> 'ModelResponse':
166
+ """
167
+ Create ModelResponse from OpenAI response object
168
+
169
+ Args:
170
+ response: OpenAI response object
171
+
172
+ Returns:
173
+ ModelResponse object
174
+
175
+ Raises:
176
+ LLMResponseError: When LLM response error occurs
177
+ """
178
+ # Handle error cases
179
+ if hasattr(response, 'error') or (isinstance(response, dict) and response.get('error')):
180
+ error_msg = response.error if hasattr(response, 'error') else response.get('error', 'Unknown error')
181
+ raise LLMResponseError(
182
+ error_msg,
183
+ response.model if hasattr(response, 'model') else response.get('model', 'unknown'),
184
+ response
185
+ )
186
+
187
+ # Normal case
188
+ message = None
189
+ if hasattr(response, 'choices') and response.choices:
190
+ message = response.choices[0].message
191
+ elif isinstance(response, dict) and response.get('choices'):
192
+ message = response['choices'][0].get('message', {})
193
+
194
+ if not message:
195
+ raise LLMResponseError(
196
+ "No message found in response",
197
+ response.model if hasattr(response, 'model') else response.get('model', 'unknown'),
198
+ response
199
+ )
200
+
201
+ # Extract usage information
202
+ usage = {}
203
+ if hasattr(response, 'usage'):
204
+ usage = {
205
+ "completion_tokens": response.usage.completion_tokens if hasattr(response.usage,
206
+ 'completion_tokens') else 0,
207
+ "prompt_tokens": response.usage.prompt_tokens if hasattr(response.usage, 'prompt_tokens') else 0,
208
+ "total_tokens": response.usage.total_tokens if hasattr(response.usage, 'total_tokens') else 0
209
+ }
210
+ elif isinstance(response, dict) and response.get('usage'):
211
+ usage = response['usage']
212
+
213
+ # Build message object
214
+ message_dict = {}
215
+ if hasattr(message, '__dict__'):
216
+ # Convert object to dictionary
217
+ for key, value in message.__dict__.items():
218
+ if not key.startswith('_'):
219
+ message_dict[key] = value
220
+ elif isinstance(message, dict):
221
+ message_dict = message
222
+ else:
223
+ # Extract common properties
224
+ message_dict = {
225
+ "role": "assistant",
226
+ "content": message.content if hasattr(message, 'content') else "",
227
+ "tool_calls": message.tool_calls if hasattr(message, 'tool_calls') else None,
228
+ }
229
+
230
+ message_dict["content"] = '' if message_dict.get('content') is None else message_dict.get('content', '')
231
+
232
+ # Process tool calls
233
+ processed_tool_calls = []
234
+ raw_tool_calls = message.tool_calls if hasattr(message, 'tool_calls') else message_dict.get('tool_calls')
235
+ if raw_tool_calls:
236
+ for tool_call in raw_tool_calls:
237
+ if isinstance(tool_call, dict):
238
+ processed_tool_calls.append(ToolCall.from_dict(tool_call))
239
+ else:
240
+ # Handle OpenAI object
241
+ tool_call_dict = {
242
+ "id": tool_call.id if hasattr(tool_call,
243
+ 'id') else f"call_{hash(str(tool_call)) & 0xffffffff:08x}",
244
+ "type": tool_call.type if hasattr(tool_call, 'type') else "function"
245
+ }
246
+
247
+ if hasattr(tool_call, 'function'):
248
+ function = tool_call.function
249
+ tool_call_dict["function"] = {
250
+ "name": function.name if hasattr(function, 'name') else None,
251
+ "arguments": function.arguments if hasattr(function, 'arguments') else None
252
+ }
253
+ processed_tool_calls.append(ToolCall.from_dict(tool_call_dict))
254
+
255
+ if message_dict and processed_tool_calls:
256
+ message_dict["tool_calls"] = [tool_call.to_dict() for tool_call in processed_tool_calls]
257
+
258
+ # Create and return ModelResponse
259
+ return cls(
260
+ id=response.id if hasattr(response, 'id') else response.get('id', 'unknown'),
261
+ model=response.model if hasattr(response, 'model') else response.get('model', 'unknown'),
262
+ content=message.content if hasattr(message, 'content') else message.get('content') or "",
263
+ tool_calls=processed_tool_calls or None,
264
+ usage=usage,
265
+ raw_response=response,
266
+ message=message_dict
267
+ )
268
+
269
+ @classmethod
270
+ def from_openai_stream_chunk(cls, chunk: Any) -> 'ModelResponse':
271
+ """
272
+ Create ModelResponse from OpenAI stream response chunk
273
+
274
+ Args:
275
+ chunk: OpenAI stream chunk
276
+
277
+ Returns:
278
+ ModelResponse object
279
+
280
+ Raises:
281
+ LLMResponseError: When LLM response error occurs
282
+ """
283
+ # Handle error cases
284
+ if hasattr(chunk, 'error') or (isinstance(chunk, dict) and chunk.get('error')):
285
+ error_msg = chunk.error if hasattr(chunk, 'error') else chunk.get('error', 'Unknown error')
286
+ raise LLMResponseError(
287
+ error_msg,
288
+ chunk.model if hasattr(chunk, 'model') else chunk.get('model', 'unknown'),
289
+ chunk
290
+ )
291
+
292
+ # Handle finish reason chunk (end of stream)
293
+ if hasattr(chunk, 'choices') and chunk.choices and chunk.choices[0].finish_reason:
294
+ return cls(
295
+ id=chunk.id if hasattr(chunk, 'id') else chunk.get('id', 'unknown'),
296
+ model=chunk.model if hasattr(chunk, 'model') else chunk.get('model', 'unknown'),
297
+ content=None,
298
+ raw_response=chunk,
299
+ message={"role": "assistant", "content": "", "finish_reason": chunk.choices[0].finish_reason}
300
+ )
301
+
302
+ # Normal chunk with delta content
303
+ content = None
304
+ processed_tool_calls = []
305
+
306
+ if hasattr(chunk, 'choices') and chunk.choices:
307
+ delta = chunk.choices[0].delta
308
+ if hasattr(delta, 'content') and delta.content:
309
+ content = delta.content
310
+ if hasattr(delta, 'tool_calls') and delta.tool_calls:
311
+ raw_tool_calls = delta.tool_calls
312
+ for tool_call in raw_tool_calls:
313
+ if isinstance(tool_call, dict):
314
+ processed_tool_calls.append(ToolCall.from_dict(tool_call))
315
+ else:
316
+ # Handle OpenAI object
317
+ tool_call_dict = {
318
+ "id": tool_call.id if hasattr(tool_call,
319
+ 'id') else f"call_{hash(str(tool_call)) & 0xffffffff:08x}",
320
+ "type": tool_call.type if hasattr(tool_call, 'type') else "function"
321
+ }
322
+
323
+ if hasattr(tool_call, 'function'):
324
+ function = tool_call.function
325
+ tool_call_dict["function"] = {
326
+ "name": function.name if hasattr(function, 'name') else None,
327
+ "arguments": function.arguments if hasattr(function, 'arguments') else None
328
+ }
329
+
330
+ processed_tool_calls.append(ToolCall.from_dict(tool_call_dict))
331
+ elif isinstance(chunk, dict) and chunk.get('choices'):
332
+ delta = chunk['choices'][0].get('delta', {})
333
+ if not delta:
334
+ delta = chunk['choices'][0].get('message', {})
335
+ content = delta.get('content')
336
+ raw_tool_calls = delta.get('tool_calls')
337
+ if raw_tool_calls:
338
+ for tool_call in raw_tool_calls:
339
+ processed_tool_calls.append(ToolCall.from_dict(tool_call))
340
+
341
+ # Extract usage information
342
+ usage = {}
343
+ if hasattr(chunk, 'usage'):
344
+ usage = {
345
+ "completion_tokens": chunk.usage.completion_tokens if hasattr(chunk.usage, 'completion_tokens') else 0,
346
+ "prompt_tokens": chunk.usage.prompt_tokens if hasattr(chunk.usage, 'prompt_tokens') else 0,
347
+ "total_tokens": chunk.usage.total_tokens if hasattr(chunk.usage, 'total_tokens') else 0
348
+ }
349
+ elif isinstance(chunk, dict) and chunk.get('usage'):
350
+ usage = chunk['usage']
351
+
352
+ # Create message object
353
+ message = {
354
+ "role": "assistant",
355
+ "content": content or "",
356
+ "tool_calls": [tool_call.to_dict() for tool_call in processed_tool_calls] if processed_tool_calls else None,
357
+ "is_chunk": True
358
+ }
359
+
360
+ # Create and return ModelResponse
361
+ return cls(
362
+ id=chunk.id if hasattr(chunk, 'id') else chunk.get('id', 'unknown'),
363
+ model=chunk.model if hasattr(chunk, 'model') else chunk.get('model', 'unknown'),
364
+ content=content,
365
+ tool_calls=processed_tool_calls or None,
366
+ usage=usage,
367
+ raw_response=chunk,
368
+ message=message
369
+ )
370
+
371
+ @classmethod
372
+ def from_anthropic_stream_chunk(cls, chunk: Any) -> 'ModelResponse':
373
+ """
374
+ Create ModelResponse from Anthropic stream response chunk
375
+
376
+ Args:
377
+ chunk: Anthropic stream chunk
378
+
379
+ Returns:
380
+ ModelResponse object
381
+
382
+ Raises:
383
+ LLMResponseError: When LLM response error occurs
384
+ """
385
+ try:
386
+ # Handle error cases
387
+ if not chunk or (isinstance(chunk, dict) and chunk.get('error')):
388
+ error_msg = chunk.get('error', 'Unknown error') if isinstance(chunk, dict) else 'Empty response'
389
+ raise LLMResponseError(
390
+ error_msg,
391
+ chunk.model if hasattr(chunk, 'model') else chunk.get('model', 'unknown'),
392
+ chunk)
393
+
394
+ # Handle stop reason (end of stream)
395
+ if hasattr(chunk, 'stop_reason') and chunk.stop_reason:
396
+ return cls(
397
+ id=chunk.id if hasattr(chunk, 'id') else 'unknown',
398
+ model=chunk.model if hasattr(chunk, 'model') else 'claude',
399
+ content=None,
400
+ raw_response=chunk,
401
+ message={"role": "assistant", "content": "", "stop_reason": chunk.stop_reason}
402
+ )
403
+
404
+ # Handle delta content
405
+ content = None
406
+ processed_tool_calls = []
407
+
408
+ if hasattr(chunk, 'delta') and chunk.delta:
409
+ delta = chunk.delta
410
+ if hasattr(delta, 'text') and delta.text:
411
+ content = delta.text
412
+ elif hasattr(delta, 'tool_use') and delta.tool_use:
413
+ tool_call_dict = {
414
+ "id": f"call_{delta.tool_use.id}",
415
+ "type": "function",
416
+ "function": {
417
+ "name": delta.tool_use.name,
418
+ "arguments": delta.tool_use.input if isinstance(delta.tool_use.input, str) else json.dumps(
419
+ delta.tool_use.input, ensure_ascii=False)
420
+ }
421
+ }
422
+ processed_tool_calls.append(ToolCall.from_dict(tool_call_dict))
423
+
424
+ # Create message object
425
+ message = {
426
+ "role": "assistant",
427
+ "content": content or "",
428
+ "tool_calls": [tool_call.to_dict() for tool_call in
429
+ processed_tool_calls] if processed_tool_calls else None,
430
+ "is_chunk": True
431
+ }
432
+
433
+ # Create and return ModelResponse
434
+ return cls(
435
+ id=chunk.id if hasattr(chunk, 'id') else 'unknown',
436
+ model=chunk.model if hasattr(chunk, 'model') else 'claude',
437
+ content=content,
438
+ tool_calls=processed_tool_calls or None,
439
+ raw_response=chunk,
440
+ message=message
441
+ )
442
+
443
+ except Exception as e:
444
+ if isinstance(e, LLMResponseError):
445
+ raise e
446
+ raise LLMResponseError(
447
+ f"Error processing Anthropic stream chunk: {str(e)}",
448
+ chunk.model if hasattr(chunk, 'model') else chunk.get('model', 'unknown'),
449
+ chunk)
450
+
451
+ @classmethod
452
+ def from_anthropic_response(cls, response: Any) -> 'ModelResponse':
453
+ """
454
+ Create ModelResponse from Anthropic original response object
455
+
456
+ Args:
457
+ response: Anthropic response object
458
+
459
+ Returns:
460
+ ModelResponse object
461
+
462
+ Raises:
463
+ LLMResponseError: When LLM response error occurs
464
+ """
465
+ try:
466
+ # Handle error cases
467
+ if not response or (isinstance(response, dict) and response.get('error')):
468
+ error_msg = response.get('error', 'Unknown error') if isinstance(response, dict) else 'Empty response'
469
+ raise LLMResponseError(
470
+ error_msg,
471
+ response.model if hasattr(response, 'model') else response.get('model', 'unknown'),
472
+ response)
473
+
474
+ # Build message content
475
+ message = {
476
+ "content": "",
477
+ "role": "assistant",
478
+ "tool_calls": None,
479
+ }
480
+
481
+ processed_tool_calls = []
482
+
483
+ if hasattr(response, 'content') and response.content:
484
+ for content_block in response.content:
485
+ if content_block.type == "text":
486
+ message["content"] = content_block.text
487
+ elif content_block.type == "tool_use":
488
+ tool_call_dict = {
489
+ "id": f"call_{content_block.id}",
490
+ "type": "function",
491
+ "function": {
492
+ "name": content_block.name,
493
+ "arguments": content_block.input if isinstance(content_block.input,
494
+ str) else json.dumps(content_block.input)
495
+ }
496
+ }
497
+ processed_tool_calls.append(ToolCall.from_dict(tool_call_dict))
498
+ else:
499
+ message["content"] = ""
500
+
501
+ if processed_tool_calls:
502
+ message["tool_calls"] = [tool_call.to_dict() for tool_call in processed_tool_calls]
503
+
504
+ # Extract usage information
505
+ usage = {
506
+ "completion_tokens": 0,
507
+ "prompt_tokens": 0,
508
+ "total_tokens": 0
509
+ }
510
+
511
+ if hasattr(response, 'usage'):
512
+ if hasattr(response.usage, 'output_tokens'):
513
+ usage["completion_tokens"] = response.usage.output_tokens
514
+ if hasattr(response.usage, 'input_tokens'):
515
+ usage["prompt_tokens"] = response.usage.input_tokens
516
+ if hasattr(response.usage, 'input_tokens') and hasattr(response.usage, 'output_tokens'):
517
+ usage["total_tokens"] = response.usage.input_tokens + response.usage.output_tokens
518
+
519
+ # Create ModelResponse
520
+ return cls(
521
+ id=response.id if hasattr(response,
522
+ 'id') else f"chatcmpl-anthropic-{hash(str(response)) & 0xffffffff:08x}",
523
+ model=response.model if hasattr(response, 'model') else "claude",
524
+ content=message["content"],
525
+ tool_calls=processed_tool_calls or None,
526
+ usage=usage,
527
+ raw_response=response,
528
+ message=message
529
+ )
530
+ except Exception as e:
531
+ if isinstance(e, LLMResponseError):
532
+ raise e
533
+ raise LLMResponseError(
534
+ f"Error processing Anthropic response: {str(e)}",
535
+ response.model if hasattr(response, 'model') else response.get('model', 'unknown'),
536
+ response)
537
+
538
+ @classmethod
539
+ def from_error(cls, error_msg: str, model: str = "unknown") -> 'ModelResponse':
540
+ """
541
+ Create ModelResponse from error message
542
+
543
+ Args:
544
+ error_msg: Error message
545
+ model: Model name
546
+
547
+ Returns:
548
+ ModelResponse object
549
+ """
550
+ return cls(
551
+ id="error",
552
+ model=model,
553
+ error=error_msg,
554
+ message={"role": "assistant", "content": f"Error: {error_msg}"}
555
+ )
556
+
557
+ def to_dict(self) -> Dict[str, Any]:
558
+ """
559
+ Convert ModelResponse to dictionary representation
560
+
561
+ Returns:
562
+ Dictionary representation
563
+ """
564
+ tool_calls_dict = None
565
+ if self.tool_calls:
566
+ tool_calls_dict = [tool_call.to_dict() for tool_call in self.tool_calls]
567
+
568
+ return {
569
+ "id": self.id,
570
+ "model": self.model,
571
+ "content": self.content,
572
+ "tool_calls": tool_calls_dict,
573
+ "usage": self.usage,
574
+ "error": self.error,
575
+ "message": self.message
576
+ }
577
+
578
+ def get_message(self) -> Dict[str, Any]:
579
+ """
580
+ Return message object that can be directly used for subsequent API calls
581
+
582
+ Returns:
583
+ Message object dictionary
584
+ """
585
+ return self.message
586
+
587
+ def serialize_tool_calls(self) -> List[Dict[str, Any]]:
588
+ """
589
+ Convert tool call objects to JSON format, handling OpenAI object types
590
+
591
+ Returns:
592
+ List[Dict[str, Any]]: Tool calls list in JSON format
593
+ """
594
+ if not self.tool_calls:
595
+ return []
596
+
597
+ result = []
598
+ for tool_call in self.tool_calls:
599
+ if hasattr(tool_call, 'to_dict'):
600
+ result.append(tool_call.to_dict())
601
+ elif isinstance(tool_call, dict):
602
+ result.append(tool_call)
603
+ else:
604
+ result.append(str(tool_call))
605
+ return result
606
+
607
+ def __repr__(self):
608
+ return json.dumps(self.to_dict(), ensure_ascii=False, indent=None,
609
+ default=lambda obj: obj.to_dict() if hasattr(obj, 'to_dict') else str(obj))
610
+
611
+ def _serialize_message(self) -> Dict[str, Any]:
612
+ """
613
+ Serialize message object
614
+
615
+ Returns:
616
+ Dict[str, Any]: Serialized message dictionary
617
+ """
618
+ if not self.message:
619
+ return {}
620
+
621
+ result = {}
622
+
623
+ # Copy basic fields
624
+ for key, value in self.message.items():
625
+ if key == 'tool_calls':
626
+ # Handle tool_calls
627
+ result[key] = self.serialize_tool_calls()
628
+ else:
629
+ result[key] = value
630
+
631
+ return result
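
A short sketch of how `ToolCall` and `ModelResponse` round-trip tool-call data; the payload values are hypothetical:

```python
from aworld.models.model_response import ModelResponse, ToolCall

# Build a ToolCall from the OpenAI-style dictionary shape handled by from_dict.
tool_call = ToolCall.from_dict({
    "id": "call_123",
    "type": "function",
    "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
})

response = ModelResponse(
    id="resp_1",
    model="gpt-4o",
    content="",
    tool_calls=[tool_call],
)

print(response.get_message())           # assistant message with serialized tool_calls
print(response.serialize_tool_calls())  # [{'id': 'call_123', ...}]

# Errors can be wrapped in a uniform way:
err = ModelResponse.from_error("timeout", model="gpt-4o")
print(err.error, err.message["content"])
```
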
aworld/models/openai_provider.py ADDED
@@ -0,0 +1,633 @@
1
+ import os
2
+ from typing import Any, Dict, List, Generator, AsyncGenerator
3
+
4
+ from openai import OpenAI, AsyncOpenAI
5
+
6
+ from aworld.config.conf import ClientType
7
+ from aworld.core.llm_provider_base import LLMProviderBase
8
+ from aworld.models.llm_http_handler import LLMHTTPHandler
9
+ from aworld.models.model_response import ModelResponse, LLMResponseError
10
+ from aworld.logs.util import logger
11
+ from aworld.models.utils import usage_process
12
+
13
+
14
+ class OpenAIProvider(LLMProviderBase):
15
+ """OpenAI provider implementation.
16
+ """
17
+
18
+ def _init_provider(self):
19
+ """Initialize OpenAI provider.
20
+
21
+ Returns:
22
+ OpenAI provider instance.
23
+ """
24
+ # Get API key
25
+ api_key = self.api_key
26
+ if not api_key:
27
+ env_var = "OPENAI_API_KEY"
28
+ api_key = os.getenv(env_var, "")
29
+ if not api_key:
30
+ raise ValueError(
31
+ f"OpenAI API key not found, please set {env_var} environment variable or provide it in the parameters")
32
+ base_url = self.base_url
33
+ if not base_url:
34
+ base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")
35
+
36
+ self.is_http_provider = False
37
+ if self.kwargs.get("client_type", ClientType.SDK) == ClientType.HTTP:
38
+ logger.info(f"Using HTTP provider for OpenAI")
39
+ self.http_provider = LLMHTTPHandler(
40
+ base_url=base_url,
41
+ api_key=api_key,
42
+ model_name=self.model_name,
43
+ max_retries=self.kwargs.get("max_retries", 3)
44
+ )
45
+ self.is_http_provider = True
46
+ return self.http_provider
47
+ else:
48
+ return OpenAI(
49
+ api_key=api_key,
50
+ base_url=base_url,
51
+ timeout=self.kwargs.get("timeout", 180),
52
+ max_retries=self.kwargs.get("max_retries", 3)
53
+ )
54
+
55
+ def _init_async_provider(self):
56
+ """Initialize async OpenAI provider.
57
+
58
+ Returns:
59
+ Async OpenAI provider instance.
60
+ """
61
+ # Get API key
62
+ api_key = self.api_key
63
+ if not api_key:
64
+ env_var = "OPENAI_API_KEY"
65
+ api_key = os.getenv(env_var, "")
66
+ if not api_key:
67
+ raise ValueError(
68
+ f"OpenAI API key not found, please set {env_var} environment variable or provide it in the parameters")
69
+ base_url = self.base_url
70
+ if not base_url:
71
+ base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")
72
+
73
+ return AsyncOpenAI(
74
+ api_key=api_key,
75
+ base_url=base_url,
76
+ timeout=self.kwargs.get("timeout", 180),
77
+ max_retries=self.kwargs.get("max_retries", 3)
78
+ )
79
+
80
+ @classmethod
81
+ def supported_models(cls) -> list[str]:
82
+ return ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini", "gpt-4o-mini", "deepseek-chat", "deepseek-reasoner",
83
+ r"qwq-.*", r"qwen-.*"]
84
+
85
+ def preprocess_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
86
+ """Preprocess messages, use OpenAI format directly.
87
+
88
+ Args:
89
+ messages: OpenAI format message list.
90
+
91
+ Returns:
92
+ Processed message list.
93
+ """
94
+ for message in messages:
95
+ if message["role"] == "assistant" and "tool_calls" in message and message["tool_calls"]:
96
+ if message["content"] is None: message["content"] = ""
97
+ for tool_call in message["tool_calls"]:
98
+ if "function" not in tool_call and "name" in tool_call and "arguments" in tool_call:
99
+ tool_call["function"] = {"name": tool_call["name"], "arguments": tool_call["arguments"]}
100
+
101
+ return messages
102
+
103
+ def postprocess_response(self, response: Any) -> ModelResponse:
104
+ """Process OpenAI response.
105
+
106
+ Args:
107
+ response: OpenAI response object.
108
+
109
+ Returns:
110
+ ModelResponse object.
111
+
112
+ Raises:
113
+ LLMResponseError: When LLM response error occurs.
114
+ """
115
+ if ((not isinstance(response, dict) and (not hasattr(response, 'choices') or not response.choices))
116
+ or (isinstance(response, dict) and not response.get("choices"))):
117
+ error_msg = ""
118
+ if hasattr(response, 'error') and response.error and isinstance(response.error, dict):
119
+ error_msg = response.error.get('message', '')
120
+ elif hasattr(response, 'msg'):
121
+ error_msg = response.msg
122
+
123
+ raise LLMResponseError(
124
+ error_msg if error_msg else "Unknown error",
125
+ self.model_name or "unknown",
126
+ response
127
+ )
128
+
129
+ return ModelResponse.from_openai_response(response)
130
+
131
+ def postprocess_stream_response(self, chunk: Any) -> ModelResponse:
132
+ """Process OpenAI streaming response chunk.
133
+
134
+ Args:
135
+ chunk: OpenAI response chunk.
136
+
137
+ Returns:
138
+ ModelResponse object.
139
+
140
+ Raises:
141
+ LLMResponseError: When LLM response error occurs.
142
+ """
143
+ # Check if chunk contains error
144
+ if hasattr(chunk, 'error') or (isinstance(chunk, dict) and chunk.get('error')):
145
+ error_msg = chunk.error if hasattr(chunk, 'error') else chunk.get('error', 'Unknown error')
146
+ raise LLMResponseError(
147
+ error_msg,
148
+ self.model_name or "unknown",
149
+ chunk
150
+ )
151
+
152
+ # process tool calls
153
+ if (hasattr(chunk, 'choices') and chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.tool_calls) or (
154
+ isinstance(chunk, dict) and chunk.get("choices") and chunk["choices"] and chunk["choices"][0].get("delta", {}).get("tool_calls")):
155
+ tool_calls = chunk.choices[0].delta.tool_calls if hasattr(chunk, 'choices') else chunk["choices"][0].get("delta", {}).get("tool_calls")
156
+
157
+ for tool_call in tool_calls:
158
+ index = tool_call.index if hasattr(tool_call, 'index') else tool_call["index"]
159
+ func_name = tool_call.function.name if hasattr(tool_call, 'function') else tool_call.get("function", {}).get("name")
160
+ func_args = tool_call.function.arguments if hasattr(tool_call, 'function') else tool_call.get("function", {}).get("arguments")
161
+ if index >= len(self.stream_tool_buffer):
162
+ self.stream_tool_buffer.append({
163
+ "id": tool_call.id if hasattr(tool_call, 'id') else tool_call.get("id"),
164
+ "type": "function",
165
+ "function": {
166
+ "name": func_name,
167
+ "arguments": func_args
168
+ }
169
+ })
170
+ else:
171
+ self.stream_tool_buffer[index]["function"]["arguments"] += func_args
172
+ processed_chunk = chunk
173
+ if hasattr(processed_chunk, 'choices'):
174
+ processed_chunk.choices[0].delta.tool_calls = None
175
+ else:
176
+ processed_chunk["choices"][0]["delta"]["tool_calls"] = None
177
+ resp = ModelResponse.from_openai_stream_chunk(processed_chunk)
178
+ if (not resp.content and not resp.usage.get("total_tokens", 0)):
179
+ return None
180
+ if (hasattr(chunk, 'choices') and chunk.choices and chunk.choices[0].finish_reason) or (
181
+ isinstance(chunk, dict) and chunk.get("choices") and chunk["choices"] and chunk["choices"][0].get(
182
+ "finish_reason")):
183
+ finish_reason = chunk.choices[0].finish_reason if hasattr(chunk, 'choices') else chunk["choices"][0].get(
184
+ "finish_reason")
185
+ if self.stream_tool_buffer:
186
+ tool_call_chunk = {
187
+ "id": chunk.id if hasattr(chunk, 'id') else chunk.get("id"),
188
+ "model": chunk.model if hasattr(chunk, 'model') else chunk.get("model"),
189
+ "object": chunk.object if hasattr(chunk, 'object') else chunk.get("object"),
190
+ "choices": [
191
+ {
192
+ "delta": {
193
+ "role": "assistant",
194
+ "content": "",
195
+ "tool_calls": self.stream_tool_buffer
196
+ }
197
+ }
198
+ ]
199
+ }
200
+ self.stream_tool_buffer = []
201
+ return ModelResponse.from_openai_stream_chunk(tool_call_chunk)
202
+
203
+ return ModelResponse.from_openai_stream_chunk(chunk)
204
+
205
+ def completion(self,
206
+ messages: List[Dict[str, str]],
207
+ temperature: float = 0.0,
208
+ max_tokens: int = None,
209
+ stop: List[str] = None,
210
+ **kwargs) -> ModelResponse:
211
+ """Synchronously call OpenAI to generate response.
212
+
213
+ Args:
214
+ messages: Message list.
215
+ temperature: Temperature parameter.
216
+ max_tokens: Maximum number of tokens to generate.
217
+ stop: List of stop sequences.
218
+ **kwargs: Other parameters.
219
+
220
+ Returns:
221
+ ModelResponse object.
222
+
223
+ Raises:
224
+ LLMResponseError: When LLM response error occurs.
225
+ """
226
+ if not self.provider:
227
+ raise RuntimeError(
228
+ "Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")
229
+
230
+ processed_messages = self.preprocess_messages(messages)
231
+
232
+ try:
233
+ openai_params = self.get_openai_params(processed_messages, temperature, max_tokens, stop, **kwargs)
234
+ if self.is_http_provider:
235
+ response = self.http_provider.sync_call(openai_params)
236
+ else:
237
+ response = self.provider.chat.completions.create(**openai_params)
238
+
239
+ if (hasattr(response, 'code') and response.code != 0) or (
240
+ isinstance(response, dict) and response.get("code", 0) != 0):
241
+ error_msg = getattr(response, 'msg', 'Unknown error')
242
+ logger.warn(f"API Error: {error_msg}")
243
+ raise LLMResponseError(error_msg, kwargs.get("model_name", self.model_name or "unknown"), response)
244
+
245
+ if not response:
246
+ raise LLMResponseError("Empty response", kwargs.get("model_name", self.model_name or "unknown"))
247
+
248
+ resp = self.postprocess_response(response)
249
+ usage_process(resp.usage)
250
+ return resp
251
+ except Exception as e:
252
+ if isinstance(e, LLMResponseError):
253
+ raise e
254
+ logger.warn(f"Error in OpenAI completion: {e}")
255
+ raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))
256
+
257
+ def stream_completion(self,
258
+ messages: List[Dict[str, str]],
259
+ temperature: float = 0.0,
260
+ max_tokens: int = None,
261
+ stop: List[str] = None,
262
+ **kwargs) -> Generator[ModelResponse, None, None]:
263
+ """Synchronously call OpenAI to generate streaming response.
264
+
265
+ Args:
266
+ messages: Message list.
267
+ temperature: Temperature parameter.
268
+ max_tokens: Maximum number of tokens to generate.
269
+ stop: List of stop sequences.
270
+ **kwargs: Other parameters.
271
+
272
+ Returns:
273
+ Generator yielding ModelResponse chunks.
274
+
275
+ Raises:
276
+ LLMResponseError: When LLM response error occurs.
277
+ """
278
+ if not self.provider:
279
+ raise RuntimeError(
280
+ "Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")
281
+
282
+ processed_messages = self.preprocess_messages(messages)
283
+ usage = {
284
+ "completion_tokens": 0,
285
+ "prompt_tokens": 0,
286
+ "total_tokens": 0
287
+ }
288
+
289
+ try:
290
+ openai_params = self.get_openai_params(processed_messages, temperature, max_tokens, stop, **kwargs)
291
+ openai_params["stream"] = True
292
+ if self.is_http_provider:
293
+ response_stream = self.http_provider.sync_stream_call(openai_params)
294
+ else:
295
+ response_stream = self.provider.chat.completions.create(**openai_params)
296
+
297
+ for chunk in response_stream:
298
+ if not chunk:
299
+ continue
300
+ resp = self.postprocess_stream_response(chunk)
301
+ if resp:
302
+ self._accumulate_chunk_usage(usage, resp.usage)
303
+ yield resp
304
+ usage_process(usage)
305
+
306
+ except Exception as e:
307
+ logger.warn(f"Error in stream_completion: {e}")
308
+ raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))
309
+
310
+ async def astream_completion(self,
311
+ messages: List[Dict[str, str]],
312
+ temperature: float = 0.0,
313
+ max_tokens: int = None,
314
+ stop: List[str] = None,
315
+ **kwargs) -> AsyncGenerator[ModelResponse, None]:
316
+ """Asynchronously call OpenAI to generate streaming response.
317
+
318
+ Args:
319
+ messages: Message list.
320
+ temperature: Temperature parameter.
321
+ max_tokens: Maximum number of tokens to generate.
322
+ stop: List of stop sequences.
323
+ **kwargs: Other parameters.
324
+
325
+ Returns:
326
+ AsyncGenerator yielding ModelResponse chunks.
327
+
328
+ Raises:
329
+ LLMResponseError: When LLM response error occurs.
330
+ """
331
+ if not self.async_provider:
332
+ raise RuntimeError(
333
+ "Async provider not initialized. Make sure 'async_enabled' parameter is set to True in initialization.")
334
+
335
+ processed_messages = self.preprocess_messages(messages)
336
+ usage = {
337
+ "completion_tokens": 0,
338
+ "prompt_tokens": 0,
339
+ "total_tokens": 0
340
+ }
341
+
342
+ try:
343
+ openai_params = self.get_openai_params(processed_messages, temperature, max_tokens, stop, **kwargs)
344
+ openai_params["stream"] = True
345
+
346
+ if self.is_http_provider:
347
+ async for chunk in self.http_provider.async_stream_call(openai_params):
348
+ if not chunk:
349
+ continue
350
+ resp = self.postprocess_stream_response(chunk)
351
+ if resp:
+ self._accumulate_chunk_usage(usage, resp.usage)
352
+ yield resp
353
+ else:
354
+ response_stream = await self.async_provider.chat.completions.create(**openai_params)
355
+ async for chunk in response_stream:
356
+ if not chunk:
357
+ continue
358
+ resp = self.postprocess_stream_response(chunk)
359
+ if resp:
360
+ self._accumulate_chunk_usage(usage, resp.usage)
361
+ yield resp
362
+ usage_process(usage)
363
+
364
+ except Exception as e:
365
+ logger.warn(f"Error in astream_completion: {e}")
366
+ raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))
367
+
368
+ async def acompletion(self,
369
+ messages: List[Dict[str, str]],
370
+ temperature: float = 0.0,
371
+ max_tokens: int = None,
372
+ stop: List[str] = None,
373
+ **kwargs) -> ModelResponse:
374
+ """Asynchronously call OpenAI to generate response.
375
+
376
+ Args:
377
+ messages: Message list.
378
+ temperature: Temperature parameter.
379
+ max_tokens: Maximum number of tokens to generate.
380
+ stop: List of stop sequences.
381
+ **kwargs: Other parameters.
382
+
383
+ Returns:
384
+ ModelResponse object.
385
+
386
+ Raises:
387
+ LLMResponseError: When LLM response error occurs.
388
+ """
389
+ if not self.async_provider:
390
+ raise RuntimeError(
391
+ "Async provider not initialized. Make sure 'async_enabled' parameter is set to True in initialization.")
392
+
393
+ processed_messages = self.preprocess_messages(messages)
394
+
395
+ try:
396
+ openai_params = self.get_openai_params(processed_messages, temperature, max_tokens, stop, **kwargs)
397
+ if self.is_http_provider:
398
+ response = await self.http_provider.async_call(openai_params)
399
+ else:
400
+ response = await self.async_provider.chat.completions.create(**openai_params)
401
+
402
+ if (hasattr(response, 'code') and response.code != 0) or (
403
+ isinstance(response, dict) and response.get("code", 0) != 0):
404
+ error_msg = getattr(response, 'msg', 'Unknown error')
405
+ logger.warn(f"API Error: {error_msg}")
406
+ raise LLMResponseError(error_msg, kwargs.get("model_name", self.model_name or "unknown"), response)
407
+
408
+ if not response:
409
+ raise LLMResponseError("Empty response", kwargs.get("model_name", self.model_name or "unknown"))
410
+
411
+ resp = self.postprocess_response(response)
412
+ usage_process(resp.usage)
413
+ return resp
414
+ except Exception as e:
415
+ if isinstance(e, LLMResponseError):
416
+ raise e
417
+ logger.warn(f"Error in acompletion: {e}")
418
+ raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))
419
+
420
+ def get_openai_params(self,
421
+ messages: List[Dict[str, str]],
422
+ temperature: float = 0.0,
423
+ max_tokens: int = None,
424
+ stop: List[str] = None,
425
+ **kwargs) -> Dict[str, Any]:
426
+ openai_params = {
427
+ "model": kwargs.get("model_name", self.model_name or ""),
428
+ "messages": messages,
429
+ "temperature": temperature,
430
+ "max_tokens": max_tokens,
431
+ "stop": stop
432
+ }
433
+
434
+ supported_params = [
435
+ "max_completion_tokens", "meta_data", "modalities", "n", "parallel_tool_calls",
436
+ "prediction", "reasoning_effort", "service_tier", "stream_options", "web_search_options",
437
+ "frequency_penalty", "logit_bias", "logprobs", "top_logprobs",
438
+ "presence_penalty", "response_format", "seed", "stream", "top_p",
439
+ "user", "function_call", "functions", "tools", "tool_choice"
440
+ ]
441
+
442
+ for param in supported_params:
443
+ if param in kwargs:
444
+ openai_params[param] = kwargs[param]
445
+
446
+ return openai_params
447
+
448
+ def speech_to_text(self,
449
+ audio_file: str,
450
+ language: str = None,
451
+ prompt: str = None,
452
+ **kwargs) -> ModelResponse:
453
+ """Convert speech to text.
454
+
455
+ Uses OpenAI's speech-to-text API to convert audio files to text.
456
+
457
+ Args:
458
+ audio_file: Path to audio file or file object.
459
+ language: Audio language, optional.
460
+ prompt: Transcription prompt, optional.
461
+ **kwargs: Other parameters, may include:
462
+ - model: Transcription model name, defaults to "whisper-1".
463
+ - response_format: Response format, defaults to "text".
464
+ - temperature: Sampling temperature, defaults to 0.
465
+
466
+ Returns:
467
+ ModelResponse: Unified model response object, with content field containing the transcription result.
468
+
469
+ Raises:
470
+ LLMResponseError: When LLM response error occurs.
471
+ """
472
+ if not self.provider:
473
+ raise RuntimeError(
474
+ "Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")
475
+
476
+ try:
477
+ # Prepare parameters
478
+ transcription_params = {
479
+ "model": kwargs.get("model", "whisper-1"),
480
+ "response_format": kwargs.get("response_format", "text"),
481
+ "temperature": kwargs.get("temperature", 0)
482
+ }
483
+
484
+ # Add optional parameters
485
+ if language:
486
+ transcription_params["language"] = language
487
+ if prompt:
488
+ transcription_params["prompt"] = prompt
489
+
490
+ # Open file (if path is provided)
491
+ if isinstance(audio_file, str):
492
+ with open(audio_file, "rb") as file:
493
+ transcription_response = self.provider.audio.transcriptions.create(
494
+ file=file,
495
+ **transcription_params
496
+ )
497
+ else:
498
+ # If already a file object
499
+ transcription_response = self.provider.audio.transcriptions.create(
500
+ file=audio_file,
501
+ **transcription_params
502
+ )
503
+
504
+ # Create ModelResponse
505
+ return ModelResponse(
506
+ id=f"stt-{hash(str(transcription_response)) & 0xffffffff:08x}",
507
+ model=transcription_params["model"],
508
+ content=transcription_response.text if hasattr(transcription_response, 'text') else str(
509
+ transcription_response),
510
+ raw_response=transcription_response,
511
+ message={
512
+ "role": "assistant",
513
+ "content": transcription_response.text if hasattr(transcription_response, 'text') else str(
514
+ transcription_response)
515
+ }
516
+ )
517
+ except Exception as e:
518
+ logger.warn(f"Speech-to-text error: {e}")
519
+ raise LLMResponseError(str(e), kwargs.get("model", "whisper-1"))
520
+
521
+ async def aspeech_to_text(self,
522
+ audio_file: str,
523
+ language: str = None,
524
+ prompt: str = None,
525
+ **kwargs) -> ModelResponse:
526
+ """Asynchronously convert speech to text.
527
+
528
+ Uses OpenAI's speech-to-text API to convert audio files to text.
529
+
530
+ Args:
531
+ audio_file: Path to audio file or file object.
532
+ language: Audio language, optional.
533
+ prompt: Transcription prompt, optional.
534
+ **kwargs: Other parameters, may include:
535
+ - model: Transcription model name, defaults to "whisper-1".
536
+ - response_format: Response format, defaults to "text".
537
+ - temperature: Sampling temperature, defaults to 0.
538
+
539
+ Returns:
540
+ ModelResponse: Unified model response object, with content field containing the transcription result.
541
+
542
+ Raises:
543
+ LLMResponseError: When LLM response error occurs.
544
+ """
545
+ if not self.async_provider:
546
+ raise RuntimeError(
547
+ "Async provider not initialized. Make sure 'async_enabled' parameter is set to True in initialization.")
548
+
549
+ try:
550
+ # Prepare parameters
551
+ transcription_params = {
552
+ "model": kwargs.get("model", "whisper-1"),
553
+ "response_format": kwargs.get("response_format", "text"),
554
+ "temperature": kwargs.get("temperature", 0)
555
+ }
556
+
557
+ # Add optional parameters
558
+ if language:
559
+ transcription_params["language"] = language
560
+ if prompt:
561
+ transcription_params["prompt"] = prompt
562
+
563
+ # Open file (if path is provided)
564
+ if isinstance(audio_file, str):
565
+ with open(audio_file, "rb") as file:
566
+ transcription_response = await self.async_provider.audio.transcriptions.create(
567
+ file=file,
568
+ **transcription_params
569
+ )
570
+ else:
571
+ # If already a file object
572
+ transcription_response = await self.async_provider.audio.transcriptions.create(
573
+ file=audio_file,
574
+ **transcription_params
575
+ )
576
+
577
+ # Create ModelResponse
578
+ return ModelResponse(
579
+ id=f"stt-{hash(str(transcription_response)) & 0xffffffff:08x}",
580
+ model=transcription_params["model"],
581
+ content=transcription_response.text if hasattr(transcription_response, 'text') else str(
582
+ transcription_response),
583
+ raw_response=transcription_response,
584
+ message={
585
+ "role": "assistant",
586
+ "content": transcription_response.text if hasattr(transcription_response, 'text') else str(
587
+ transcription_response)
588
+ }
589
+ )
590
+ except Exception as e:
591
+ logger.warn(f"Async speech-to-text error: {e}")
592
+ raise LLMResponseError(str(e), kwargs.get("model", "whisper-1"))
593
+
594
+
595
+ class AzureOpenAIProvider(OpenAIProvider):
596
+ """Azure OpenAI provider implementation.
597
+ """
598
+
599
+ def _init_provider(self):
600
+ """Initialize Azure OpenAI provider.
601
+
602
+ Returns:
603
+ Azure OpenAI provider instance.
604
+ """
605
+ from langchain_openai import AzureChatOpenAI
606
+
607
+ # Get API key
608
+ api_key = self.api_key
609
+ if not api_key:
610
+ env_var = "AZURE_OPENAI_API_KEY"
611
+ api_key = os.getenv(env_var, "")
612
+ if not api_key:
613
+ raise ValueError(
614
+ f"Azure OpenAI API key not found, please set {env_var} environment variable or provide it in the parameters")
615
+
616
+ # Get API version
617
+ api_version = self.kwargs.get("api_version", "") or os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview")
618
+
619
+ # Get endpoint
620
+ azure_endpoint = self.base_url
621
+ if not azure_endpoint:
622
+ azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT", "")
623
+ if not azure_endpoint:
624
+ raise ValueError(
625
+ "Azure OpenAI endpoint not found, please set AZURE_OPENAI_ENDPOINT environment variable or provide it in the parameters")
626
+
627
+ return AzureChatOpenAI(
628
+ model=self.model_name or "gpt-4o",
629
+ temperature=self.kwargs.get("temperature", 0.0),
630
+ api_version=api_version,
631
+ azure_endpoint=azure_endpoint,
632
+ api_key=api_key
633
+ )
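
For reference, a small sketch of the message normalization that `OpenAIProvider.preprocess_messages` performs before each call: an assistant turn whose tool call was stored in flattened form (with `name`/`arguments` at the top level) gains a proper `function` sub-dict, and a `None` content becomes an empty string. The `calculator` tool and the `provider` instance are hypothetical:

```python
# Sketch only: `provider` is assumed to be an initialized OpenAIProvider.
messages = [
    {"role": "user", "content": "What is 2 + 2?"},
    {
        "role": "assistant",
        "content": None,  # rewritten to "" by preprocess_messages
        "tool_calls": [
            {"id": "call_1", "type": "function",
             "name": "calculator", "arguments": '{"expression": "2 + 2"}'},
        ],
    },
]

processed = provider.preprocess_messages(messages)
# processed[1]["content"] == ""
# processed[1]["tool_calls"][0]["function"] == {
#     "name": "calculator", "arguments": '{"expression": "2 + 2"}'}
```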
aworld/models/openai_tokenizer.py ADDED
@@ -0,0 +1,237 @@
1
+ # Copyright 2024 AWorld Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tokenization classes for OpenAI models."""
16
+
17
+ import base64
18
+ import unicodedata
19
+ from pathlib import Path
20
+ from typing import Collection, Dict, List, Set, Union
21
+ from aworld.logs.util import logger
22
+ from aworld.utils import import_package
23
+ import_package("tiktoken")
24
+ import tiktoken
25
+
26
+ VOCAB_FILES_NAMES = {'vocab_file': 'cl100k_base.tiktoken'}
27
+
28
+ # OpenAI GPT tokenizer pattern
29
+ PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
30
+
31
+ # OpenAI special tokens
32
+ ENDOFTEXT = '<|endoftext|>'
33
+ SPECIAL_TOKENS = {
34
+ ENDOFTEXT: 100256,
35
+ }
36
+
37
+
38
+ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
39
+ """Load tiktoken BPE file similar to qwen_tokenizer."""
40
+ with open(tiktoken_bpe_file, 'rb') as f:
41
+ contents = f.read()
42
+ return {
43
+ base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)
44
+ }
45
+
46
+
47
+ class OpenAITokenizer:
48
+ """OpenAI tokenizer using local tiktoken file."""
49
+
50
+ vocab_files_names = VOCAB_FILES_NAMES
51
+
52
+ def __init__(
53
+ self,
54
+ vocab_file=None,
55
+ errors='replace',
56
+ extra_vocab_file=None,
57
+ ):
58
+ if not vocab_file:
59
+ vocab_file = VOCAB_FILES_NAMES['vocab_file']
60
+ self._decode_use_source_tokenizer = False
61
+
62
+ # how to handle errors in decoding UTF-8 byte sequences
63
+ # use ignore if you are in streaming inference
64
+ self.errors = errors
65
+
66
+ self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: Dict[bytes, int]
67
+ self.special_tokens = SPECIAL_TOKENS.copy()
68
+
69
+ # try load extra vocab from file
70
+ if extra_vocab_file is not None:
71
+ used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values())
72
+ extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file)
73
+ for token, index in extra_mergeable_ranks.items():
74
+ if token in self.mergeable_ranks:
75
+ logger.info(f'extra token {token} exists, skipping')
76
+ continue
77
+ if index in used_ids:
78
+ logger.info(f'the index {index} for extra token {token} exists, skipping')
79
+ continue
80
+ self.mergeable_ranks[token] = index
81
+ # the index may be sparse after this, but don't worry tiktoken.Encoding will handle this
82
+
83
+ enc = tiktoken.Encoding(
84
+ 'cl100k_base',
85
+ pat_str=PAT_STR,
86
+ mergeable_ranks=self.mergeable_ranks,
87
+ special_tokens=self.special_tokens,
88
+ )
89
+ assert len(self.mergeable_ranks) + len(
90
+ self.special_tokens
91
+ ) == enc.n_vocab, f'{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding'
92
+
93
+ self.decoder = {v: k for k, v in self.mergeable_ranks.items()} # type: dict[int, bytes|str]
94
+ self.decoder.update({v: k for k, v in self.special_tokens.items()})
95
+
96
+ self.tokenizer = enc # type: tiktoken.Encoding
97
+
98
+ self.eod_id = self.special_tokens[ENDOFTEXT]
99
+
100
+ def __getstate__(self):
101
+ # for pickle lovers
102
+ state = self.__dict__.copy()
103
+ del state['tokenizer']
104
+ return state
105
+
106
+ def __setstate__(self, state):
107
+ # tokenizer is not python native; don't pass it; rebuild it
108
+ self.__dict__.update(state)
109
+ enc = tiktoken.Encoding(
110
+ 'cl100k_base',
111
+ pat_str=PAT_STR,
112
+ mergeable_ranks=self.mergeable_ranks,
113
+ special_tokens=self.special_tokens,
114
+ )
115
+ self.tokenizer = enc
116
+
117
+ def __len__(self) -> int:
118
+ return self.tokenizer.n_vocab
119
+
120
+ def get_vocab(self) -> Dict[bytes, int]:
121
+ return self.mergeable_ranks
122
+
123
+ def convert_tokens_to_ids(self, tokens: Union[bytes, str, List[Union[bytes, str]]]) -> List[int]:
124
+ ids = []
125
+ if isinstance(tokens, (str, bytes)):
126
+ if tokens in self.special_tokens:
127
+ return self.special_tokens[tokens]
128
+ else:
129
+ return self.mergeable_ranks.get(tokens)
130
+ for token in tokens:
131
+ if token in self.special_tokens:
132
+ ids.append(self.special_tokens[token])
133
+ else:
134
+ ids.append(self.mergeable_ranks.get(token))
135
+ return ids
136
+
137
+ def tokenize(
138
+ self,
139
+ text: str,
140
+ allowed_special: Union[Set, str] = 'all',
141
+ disallowed_special: Union[Collection, str] = (),
142
+ ) -> List[Union[bytes, str]]:
143
+ """
144
+ Converts a string into a sequence of tokens.
145
+
146
+ Args:
147
+ text (`str`):
148
+ The sequence to be encoded.
149
+ allowed_special (`Literal["all"]` or `set`):
150
+ The surface forms of the tokens to be encoded as special tokens in regular texts.
151
+ Default to "all".
152
+ disallowed_special (`Literal["all"]` or `Collection`):
153
+ The surface forms of the tokens that should not be in regular texts and trigger errors.
154
+ Default to an empty tuple.
155
+
156
+ Returns:
157
+ `List[bytes|str]`: The list of tokens.
158
+ """
159
+ tokens = []
160
+ if text is None:
161
+ return tokens
162
+ text = unicodedata.normalize('NFC', text)
163
+
164
+ # this implementation takes a detour: text -> token id -> token surface forms
165
+ for t in self.tokenizer.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special):
166
+ tokens.append(self.decoder[t])
167
+ return tokens
168
+
169
+ def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
170
+ """
171
+ Converts a sequence of tokens into a single string.
172
+ """
173
+ text = ''
174
+ temp = b''
175
+ for t in tokens:
176
+ if isinstance(t, str):
177
+ if temp:
178
+ text += temp.decode('utf-8', errors=self.errors)
179
+ temp = b''
180
+ text += t
181
+ elif isinstance(t, bytes):
182
+ temp += t
183
+ else:
184
+ raise TypeError('token should only be of type bytes or str')
185
+ if temp:
186
+ text += temp.decode('utf-8', errors=self.errors)
187
+ return text
188
+
189
+ @property
190
+ def vocab_size(self):
191
+ return self.tokenizer.n_vocab
192
+
193
+ def _decode(
194
+ self,
195
+ token_ids: Union[int, List[int]],
196
+ skip_special_tokens: bool = False,
197
+ errors: str = None,
198
+ ) -> str:
199
+ if isinstance(token_ids, int):
200
+ token_ids = [token_ids]
201
+ if skip_special_tokens:
202
+ token_ids = [i for i in token_ids if i < self.eod_id]
203
+ return self.tokenizer.decode(token_ids, errors=errors or self.errors)
204
+
205
+ def encode(self, text: str) -> List[int]:
206
+ return self.tokenizer.encode(text)
207
+
208
+ def decode(self, token_ids: Union[int, List[int]], errors: str = None) -> str:
209
+ return self._decode(token_ids, errors=errors)
210
+
211
+ def count_tokens(self, text: str) -> int:
212
+ return len(self.encode(text))
213
+
214
+ def truncate(self, text: str, max_token: int, start_token: int = 0, keep_both_sides: bool = False) -> str:
215
+ max_token = int(max_token)
216
+ token_ids = self.encode(text)[start_token:]
217
+ if len(token_ids) <= max_token:
218
+ return self.decode(token_ids)
219
+
220
+ if keep_both_sides:
221
+ ellipsis_tokens = self.encode("...")
222
+ ellipsis_len = len(ellipsis_tokens)
223
+ available = max_token - ellipsis_len
224
+ if available <= 0: # Degenerate case: not enough space even for "..."
225
+ return self.decode(token_ids[:max_token])
226
+
227
+ left_len = available // 2
228
+ right_len = available - left_len
229
+ token_ids = token_ids[:left_len] + ellipsis_tokens + token_ids[-right_len:]
230
+ else:
231
+ token_ids = token_ids[:max_token]
232
+
233
+ return self.decode(token_ids)
234
+
235
+
236
+ # Default tokenizer instance using local cl100k_base.tiktoken
237
+ openai_tokenizer = OpenAITokenizer(Path(__file__).resolve().parent.parent / 'config' / 'cl100k_base.tiktoken')
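
A short usage sketch for the module-level `openai_tokenizer` instance defined above (the sample text is arbitrary):

```python
from aworld.models.openai_tokenizer import openai_tokenizer

text = "AWorld provides a unified interface for multiple LLM providers."
print(openai_tokenizer.count_tokens(text))      # number of cl100k_base tokens

# Keep roughly the first 8 tokens of the text.
print(openai_tokenizer.truncate(text, max_token=8))
```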
aworld/models/qwen_tokenizer.py ADDED
@@ -0,0 +1,245 @@
1
+ # Copyright 2023 The Qwen team, Alibaba Group. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tokenization classes for QWen."""
16
+
17
+ import base64
18
+ import unicodedata
19
+ from pathlib import Path
20
+ from typing import Collection, Dict, List, Set, Union
21
+ from aworld.logs.util import logger
22
+ from aworld.utils import import_package
23
+ import_package("tiktoken")
24
+ import tiktoken
25
+
26
+ VOCAB_FILES_NAMES = {'vocab_file': 'qwen.tiktoken'}
27
+
28
+ PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
29
+ ENDOFTEXT = '<|endoftext|>'
30
+ IMSTART = '<|im_start|>'
31
+ IMEND = '<|im_end|>'
32
+ # as the default behavior is changed to allow special tokens in
33
+ # regular texts, the surface forms of special tokens need to be
34
+ # as different as possible to minimize the impact
35
+ EXTRAS = tuple((f'<|extra_{i}|>' for i in range(205)))
36
+ # changed to use actual index to avoid misconfiguration with vocabulary expansion
37
+ SPECIAL_START_ID = 151643
38
+ SPECIAL_TOKENS = tuple(enumerate(
39
+ ((
40
+ ENDOFTEXT,
41
+ IMSTART,
42
+ IMEND,
43
+ ) + EXTRAS),
44
+ start=SPECIAL_START_ID,
45
+ ))
46
+ SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS)
47
+
48
+
49
+ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
50
+ with open(tiktoken_bpe_file, 'rb') as f:
51
+ contents = f.read()
52
+ return {
53
+ base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)
54
+ }
55
+
56
+
57
+ class QWenTokenizer:
58
+ """QWen tokenizer."""
59
+
60
+ vocab_files_names = VOCAB_FILES_NAMES
61
+
62
+ def __init__(
63
+ self,
64
+ vocab_file=None,
65
+ errors='replace',
66
+ extra_vocab_file=None,
67
+ ):
68
+ if not vocab_file:
69
+ vocab_file = VOCAB_FILES_NAMES['vocab_file']
70
+ self._decode_use_source_tokenizer = False
71
+
72
+ # how to handle errors in decoding UTF-8 byte sequences
73
+ # use ignore if you are in streaming inference
74
+ self.errors = errors
75
+
76
+ self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: Dict[bytes, int]
77
+ self.special_tokens = {token: index for index, token in SPECIAL_TOKENS}
78
+
79
+ # try load extra vocab from file
80
+ if extra_vocab_file is not None:
81
+ used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values())
82
+ extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file)
83
+ for token, index in extra_mergeable_ranks.items():
84
+ if token in self.mergeable_ranks:
85
+ logger.info(f'extra token {token} exists, skipping')
86
+ continue
87
+ if index in used_ids:
88
+ logger.info(f'the index {index} for extra token {token} exists, skipping')
89
+ continue
90
+ self.mergeable_ranks[token] = index
91
+ # the index may be sparse after this, but don't worry tiktoken.Encoding will handle this
92
+
93
+ enc = tiktoken.Encoding(
94
+ 'Qwen',
95
+ pat_str=PAT_STR,
96
+ mergeable_ranks=self.mergeable_ranks,
97
+ special_tokens=self.special_tokens,
98
+ )
99
+ assert len(self.mergeable_ranks) + len(
100
+ self.special_tokens
101
+ ) == enc.n_vocab, f'{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding'
102
+
103
+ self.decoder = {v: k for k, v in self.mergeable_ranks.items()} # type: dict[int, bytes|str]
104
+ self.decoder.update({v: k for k, v in self.special_tokens.items()})
105
+
106
+ self.tokenizer = enc # type: tiktoken.Encoding
107
+
108
+ self.eod_id = self.tokenizer.eot_token
109
+ self.im_start_id = self.special_tokens[IMSTART]
110
+ self.im_end_id = self.special_tokens[IMEND]
111
+
112
+ def __getstate__(self):
113
+ # for pickle lovers
114
+ state = self.__dict__.copy()
115
+ del state['tokenizer']
116
+ return state
117
+
118
+ def __setstate__(self, state):
119
+ # tokenizer is not python native; don't pass it; rebuild it
120
+ self.__dict__.update(state)
121
+ enc = tiktoken.Encoding(
122
+ 'Qwen',
123
+ pat_str=PAT_STR,
124
+ mergeable_ranks=self.mergeable_ranks,
125
+ special_tokens=self.special_tokens,
126
+ )
127
+ self.tokenizer = enc
128
+
129
+ def __len__(self) -> int:
130
+ return self.tokenizer.n_vocab
131
+
132
+ def get_vocab(self) -> Dict[bytes, int]:
133
+ return self.mergeable_ranks
134
+
135
+ def convert_tokens_to_ids(self, tokens: Union[bytes, str, List[Union[bytes, str]]]) -> List[int]:
136
+ ids = []
137
+ if isinstance(tokens, (str, bytes)):
138
+ if tokens in self.special_tokens:
139
+ return self.special_tokens[tokens]
140
+ else:
141
+ return self.mergeable_ranks.get(tokens)
142
+ for token in tokens:
143
+ if token in self.special_tokens:
144
+ ids.append(self.special_tokens[token])
145
+ else:
146
+ ids.append(self.mergeable_ranks.get(token))
147
+ return ids
148
+
149
+ def tokenize(
150
+ self,
151
+ text: str,
152
+ allowed_special: Union[Set, str] = 'all',
153
+ disallowed_special: Union[Collection, str] = (),
154
+ ) -> List[Union[bytes, str]]:
155
+ """
156
+ Converts a string into a sequence of tokens.
157
+
158
+ Args:
159
+ text (`str`):
160
+ The sequence to be encoded.
161
+ allowed_special (`Literal["all"]` or `set`):
162
+ The surface forms of the tokens to be encoded as special tokens in regular texts.
163
+ Default to "all".
164
+ disallowed_special (`Literal["all"]` or `Collection`):
165
+ The surface forms of the tokens that should not be in regular texts and trigger errors.
166
+ Default to an empty tuple.
167
+
168
+ Returns:
169
+ `List[bytes|str]`: The list of tokens.
170
+ """
171
+ tokens = []
172
+ if text is None:
173
+ return tokens
174
+ text = unicodedata.normalize('NFC', text)
175
+
176
+ # this implementation takes a detour: text -> token id -> token surface forms
177
+ for t in self.tokenizer.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special):
178
+ tokens.append(self.decoder[t])
179
+ return tokens
180
+
181
+ def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
182
+ """
183
+ Converts a sequence of tokens into a single string.
184
+ """
185
+ text = ''
186
+ temp = b''
187
+ for t in tokens:
188
+ if isinstance(t, str):
189
+ if temp:
190
+ text += temp.decode('utf-8', errors=self.errors)
191
+ temp = b''
192
+ text += t
193
+ elif isinstance(t, bytes):
194
+ temp += t
195
+ else:
196
+ raise TypeError('token should only be of type bytes or str')
197
+ if temp:
198
+ text += temp.decode('utf-8', errors=self.errors)
199
+ return text
200
+
201
+ @property
202
+ def vocab_size(self):
203
+ return self.tokenizer.n_vocab
204
+
205
+ def _decode(
206
+ self,
207
+ token_ids: Union[int, List[int]],
208
+ skip_special_tokens: bool = False,
209
+ errors: str = None,
210
+ ) -> str:
211
+ if isinstance(token_ids, int):
212
+ token_ids = [token_ids]
213
+ if skip_special_tokens:
214
+ token_ids = [i for i in token_ids if i < self.eod_id]
215
+ return self.tokenizer.decode(token_ids, errors=errors or self.errors)
216
+
217
+ def encode(self, text: str) -> List[int]:
218
+ return self.convert_tokens_to_ids(self.tokenize(text))
219
+
220
+ def count_tokens(self, text: str) -> int:
221
+ return len(self.tokenize(text))
222
+
223
+ def truncate(self, text: str, max_token: int, start_token: int = 0, keep_both_sides: bool = False) -> str:
224
+ max_token = int(max_token)
225
+ token_list = self.tokenize(text)[start_token:]
226
+ if len(token_list) <= max_token:
227
+ return self.convert_tokens_to_string(token_list)
228
+
229
+ if keep_both_sides:
230
+ ellipsis_tokens = self.tokenize("...")
231
+ ellipsis_len = len(ellipsis_tokens)
232
+ available = max_token - ellipsis_len
233
+ if available <= 0: # Degenerate case: not enough space even for "..."
234
+ return self.convert_tokens_to_string(token_list[:max_token])
235
+
236
+ left_len = available // 2
237
+ right_len = available - left_len
238
+ token_list = token_list[:left_len] + ellipsis_tokens + token_list[-right_len:]
239
+ else:
240
+ token_list = token_list[:max_token]
241
+
242
+ return self.convert_tokens_to_string(token_list)
243
+
244
+
245
+ qwen_tokenizer = QWenTokenizer(Path(__file__).resolve().parent.parent / 'config' / 'qwen.tiktoken')
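
Likewise, a sketch of two-sided truncation with the module-level `qwen_tokenizer` instance, which keeps the head and tail of the text and inserts `...` in the middle (the input text is arbitrary):

```python
from aworld.models.qwen_tokenizer import qwen_tokenizer

long_text = "AWorld agents often pass around fairly long tool outputs. " * 200
shortened = qwen_tokenizer.truncate(long_text, max_token=512, keep_both_sides=True)
print(qwen_tokenizer.count_tokens(shortened))   # roughly 512 tokens
```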
aworld/models/utils.py ADDED
@@ -0,0 +1,195 @@
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+ import copy
4
+ import inspect
5
+ import os.path
6
+ from typing import Dict, Any, List, Union
7
+
8
+ from aworld.core.context.base import Context
9
+ from aworld.logs.util import logger
10
+ from aworld.models.qwen_tokenizer import qwen_tokenizer
11
+ from aworld.models.openai_tokenizer import openai_tokenizer
12
+ from aworld.utils import import_package
13
+
14
+
15
+ def usage_process(usage: Dict[str, Union[int, Dict[str, int]]] = {}, context: Context = None):
16
+ if not context:
17
+ context = Context.instance()
18
+
19
+ stacks = inspect.stack()
20
+ index = 0
21
+ for idx, stack in enumerate(stacks):
22
+ index = idx + 1
23
+ file = os.path.basename(stack.filename)
24
+ # only count usage that flows through the `llm.py` utility functions
25
+ if 'call_llm_model' in stack.function and file == 'llm.py':
26
+ break
27
+
28
+ if index >= len(stacks):
29
+ logger.warning("no usage category found to count")
30
+ else:
31
+ instance = stacks[index].frame.f_locals.get('self')
32
+ name = getattr(instance, "_name", "unknown")
33
+ usage[name] = copy.copy(usage)
34
+ # total usage
35
+ context.add_token(usage)
36
+
37
+
38
+ def num_tokens_from_messages(messages, model="gpt-4o"):
39
+ """Return the number of tokens used by a list of messages."""
40
+ import_package("tiktoken")
41
+ import tiktoken
42
+
43
+ if model.lower() == "qwen":
44
+ encoding = qwen_tokenizer
45
+ elif model.lower() == "openai":
46
+ encoding = openai_tokenizer
47
+ else:
48
+ try:
49
+ encoding = tiktoken.encoding_for_model(model)
50
+ except KeyError:
51
+ logger.warning(f"{model} model not found. Using cl100k_base encoding.")
52
+ encoding = tiktoken.get_encoding("cl100k_base")
53
+
54
+ tokens_per_message = 3
55
+ tokens_per_name = 1
56
+
57
+ num_tokens = 0
58
+ for message in messages:
59
+ num_tokens += tokens_per_message
60
+ if isinstance(message, str):
61
+ num_tokens += len(encoding.encode(message))
62
+ else:
63
+ for key, value in message.items():
64
+ num_tokens += len(encoding.encode(str(value)))
65
+ if key == "name":
66
+ num_tokens += tokens_per_name
67
+ num_tokens += 3
68
+ return num_tokens
69
+
70
+ def truncate_tokens_from_messages(messages: List[Dict[str, Any]], max_tokens: int, keep_both_sides: bool = False, model: str = "gpt-4o"):
71
+ import_package("tiktoken")
72
+ import tiktoken
73
+
74
+ if model.lower() == "qwen":
75
+ return qwen_tokenizer.truncate(messages, max_tokens, keep_both_sides=keep_both_sides)
76
+ elif model.lower() == "openai":
77
+ return openai_tokenizer.truncate(messages, max_tokens, keep_both_sides=keep_both_sides)
78
+
79
+ try:
80
+ encoding = tiktoken.encoding_for_model(model)
81
+ except KeyError:
82
+ logger.warning(f"{model} model not found. Using cl100k_base encoding.")
83
+ encoding = tiktoken.get_encoding("cl100k_base")
84
+
85
+ # tiktoken's Encoding has no truncate helper; truncate on token ids directly.
+ # Note: keep_both_sides is only honored by the qwen/openai tokenizers above.
+ token_ids = encoding.encode(str(messages))[:max_tokens]
+ return encoding.decode(token_ids)
86
+
87
+ def agent_desc_transform(agent_dict: Dict[str, Any],
88
+ agents: List[str] = None,
89
+ provider: str = 'openai',
90
+ strategy: str = 'min') -> List[Dict[str, Any]]:
91
+ """Transform the framework-standard agent descriptions into the OpenAI tool protocol (default implementation).
92
+
93
+ Args:
94
+ agent_dict: Dict of descriptions of agents that are registered in the agent factory.
95
+ agents: Description of special agents to use.
96
+ provider: Different descriptions formats need to be processed based on the provider.
97
+ strategy: The value is `min` or `max`, when no special agents are provided, `min` indicates no content returned,
98
+ `max` means get all agents' descriptions.
99
+ """
100
+ agent_as_tools = []
101
+ if not agents and strategy == 'min':
102
+ return agent_as_tools
103
+
104
+ if provider and 'openai' in provider:
105
+ for agent_name, agent_info in agent_dict.items():
106
+ if agents and agent_name not in agents:
107
+ logger.debug(f"{agent_name} is not in the requested agents {agents}; add it to the `agents` param to enable it.")
108
+ continue
109
+
110
+ for action in agent_info["abilities"]:
111
+ # Build parameter properties
112
+ properties = {}
113
+ required = []
114
+ for param_name, param_info in action["params"].items():
115
+ properties[param_name] = {
116
+ "description": param_info["desc"],
117
+ "type": param_info["type"] if param_info["type"] != "str" else "string"
118
+ }
119
+ if param_info.get("required", False):
120
+ required.append(param_name)
121
+
122
+ openai_function_schema = {
123
+ "name": f'{agent_name}__{action["name"]}',
124
+ "description": action["desc"],
125
+ "parameters": {
126
+ "type": "object",
127
+ "properties": properties,
128
+ "required": required
129
+ }
130
+ }
131
+
132
+ agent_as_tools.append({
133
+ "type": "function",
134
+ "function": openai_function_schema
135
+ })
136
+ return agent_as_tools
137
+
138
+
139
+ def tool_desc_transform(tool_dict: Dict[str, Any],
140
+ tools: List[str] = None,
141
+ black_tool_actions: Dict[str, List[str]] = {},
142
+ provider: str = 'openai',
143
+ strategy: str = 'min') -> List[Dict[str, Any]]:
144
+ """Transform the framework-standard tool descriptions into the OpenAI tool protocol (default implementation).
145
+
146
+ Args:
147
+ tool_dict: Dict of descriptions of tools that are registered in the agent factory.
148
+ tools: Description of special tools to use.
149
+ provider: Different descriptions formats need to be processed based on the provider.
150
+ strategy: The value is `min` or `max`, when no special tools are provided, `min` indicates no content returned,
151
+ `max` means get all tools' descriptions.
152
+ """
153
+ openai_tools = []
154
+ if not tools and strategy == 'min':
155
+ return openai_tools
156
+
157
+ if black_tool_actions is None:
158
+ black_tool_actions = {}
159
+
160
+ if provider and 'openai' in provider:
161
+ for tool_name, tool_info in tool_dict.items():
162
+ if tools and tool_name not in tools:
163
+ logger.debug(f"{tool_name} is not in the requested tools {tools}; add it to the `tools` param to enable it.")
164
+ continue
165
+
166
+ black_actions = black_tool_actions.get(tool_name, [])
167
+ for action in tool_info["actions"]:
168
+ if action['name'] in black_actions:
169
+ continue
170
+ # Build parameter properties
171
+ properties = {}
172
+ required = []
173
+ for param_name, param_info in action["params"].items():
174
+ properties[param_name] = {
175
+ "description": param_info["desc"],
176
+ "type": param_info["type"] if param_info["type"] != "str" else "string"
177
+ }
178
+ if param_info.get("required", False):
179
+ required.append(param_name)
180
+
181
+ openai_function_schema = {
182
+ "name": f'{tool_name}__{action["name"]}',
183
+ "description": action["desc"],
184
+ "parameters": {
185
+ "type": "object",
186
+ "properties": properties,
187
+ "required": required
188
+ }
189
+ }
190
+
191
+ openai_tools.append({
192
+ "type": "function",
193
+ "function": openai_function_schema
194
+ })
195
+ return openai_tools
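
To illustrate the transform above, a hedged sketch of the OpenAI tool schema that `tool_desc_transform` builds for a single registered tool; the `browser`/`open_url` registry entry below is hypothetical and only shows the expected input shape:

```python
from aworld.models.utils import tool_desc_transform

tool_dict = {
    "browser": {
        "actions": [
            {
                "name": "open_url",
                "desc": "Open a web page in the browser.",
                "params": {
                    "url": {"desc": "Address of the page to open.",
                            "type": "str", "required": True},
                },
            },
        ],
    },
}

tools = tool_desc_transform(tool_dict, tools=["browser"])
# tools[0]["function"]["name"] == "browser__open_url"
# tools[0]["function"]["parameters"]["properties"]["url"]["type"] == "string"
# tools[0]["function"]["parameters"]["required"] == ["url"]
```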