import ast
import asyncio
import datetime
import html
import json
import os
import time
from typing import (
    Any,
    List,
    Dict,
    Generator,
    AsyncGenerator,
)
from binascii import b2a_hex

from aworld.config.conf import ClientType
from aworld.core.llm_provider_base import LLMProviderBase
from aworld.models.llm_http_handler import LLMHTTPHandler
from aworld.models.model_response import ModelResponse, LLMResponseError, ToolCall
from aworld.logs.util import logger
from aworld.utils import import_package
from aworld.models.utils import usage_process

MODEL_NAMES = {
    "anthropic": ["claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20240620", "claude-3-opus-20240229"],
    "openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini", "gpt-4o-mini"],
}

# Custom JSON encoder to handle ToolCall and other special types
class CustomJSONEncoder(json.JSONEncoder):
    """Custom JSON encoder to handle ToolCall objects and other special types."""

    def default(self, obj):
        # Handle objects with a to_dict method
        if hasattr(obj, 'to_dict') and callable(obj.to_dict):
            return obj.to_dict()
        # Handle objects with a __dict__ attribute (most custom classes)
        if hasattr(obj, '__dict__'):
            return obj.__dict__
        # Let the base class handle it (raises TypeError if not serializable)
        return super().default(obj)

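# A minimal usage sketch (hypothetical payload): plain json.dumps would raise
# TypeError on ToolCall objects, while this encoder falls back to to_dict():
#
#   payload = {"role": "assistant", "tool_calls": [some_tool_call]}
#   json.dumps(payload, cls=CustomJSONEncoder)
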
class AntProvider(LLMProviderBase):
    """Ant provider implementation."""

    def _init_provider(self):
        """Initialize the Ant provider.

        Returns:
            Ant provider instance.
        """
        import_package("Crypto", install_name="pycryptodome")

        # Get API key
        api_key = self.api_key
        if not api_key:
            env_var = "ANT_API_KEY"
            api_key = os.getenv(env_var, "")
            self.api_key = api_key
            if not api_key:
                raise ValueError(
                    f"ANT API key not found, please set {env_var} environment variable or provide it in the parameters")

        # Keys prefixed with "ak_info:" carry a JSON bundle of environment settings
        if api_key.startswith("ak_info:"):
            ak_info_str = api_key[len("ak_info:"):]
            try:
                ak_info = json.loads(ak_info_str)
                for key, value in ak_info.items():
                    os.environ[key] = value
                    if key == "ANT_API_KEY":
                        api_key = value
                        self.api_key = api_key
            except Exception:
                logger.warn(f"Invalid ANT API key starting with 'ak_info:': {api_key}")

        self.stream_api_key = os.getenv("ANT_STREAM_API_KEY", "")

        base_url = self.base_url
        if not base_url:
            base_url = os.getenv("ANT_ENDPOINT", "https://zdfmng.alipay.com")
            self.base_url = base_url
        self.aes_key = os.getenv("ANT_AES_KEY", "")

        self.kwargs["client_type"] = ClientType.HTTP
        logger.info("Using HTTP provider for Ant")
        self.http_provider = LLMHTTPHandler(
            base_url=base_url,
            api_key=api_key,
            model_name=self.model_name,
        )
        self.is_http_provider = True
        return self.http_provider

    def _init_async_provider(self):
        """Initialize the async Ant provider.

        Returns:
            Async Ant provider instance.
        """
        if not self.provider:
            provider = self._init_provider()
            return provider
        return self.provider

    @classmethod
    def supported_models(cls) -> list[str]:
        return [""]

    def _aes_encrypt(self, data, key):
        """AES encryption helper. Encrypted data must be a multiple of 16 bytes long, so shorter input is zero-padded up to the next multiple of 16.

        Args:
            data: Data to encrypt.
            key: Encryption key.

        Returns:
            Hex-encoded ciphertext.
        """
        from Crypto.Cipher import AES

        iv = "1234567890123456"
        cipher = AES.new(key.encode('utf-8'), AES.MODE_CBC, iv.encode('utf-8'))
        block_size = AES.block_size
        # Check if data is a multiple of 16; if not, pad with b'\0'
        if len(data) % block_size != 0:
            add = block_size - (len(data) % block_size)
        else:
            add = 0
        data = data.encode('utf-8') + b'\0' * add
        encrypted = cipher.encrypt(data)
        result = b2a_hex(encrypted)
        return result.decode('utf-8')

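    # A minimal sketch of what the helper produces (hypothetical 16-byte key;
    # real calls use the shared ANT_AES_KEY):
    #
    #   ciphertext = provider._aes_encrypt('{"q": 1}', "0123456789abcdef")
    #   # -> a hex string; the server decrypts with the same key and fixed IV
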
    def _build_openai_params(self,
                             messages: List[Dict[str, str]],
                             temperature: float = 0.0,
                             max_tokens: int = None,
                             stop: List[str] = None,
                             **kwargs) -> Dict[str, Any]:
        openai_params = {
            "model": kwargs.get("model_name", self.model_name or ""),
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stop": stop
        }
        supported_params = [
            "frequency_penalty", "logit_bias", "logprobs", "top_logprobs",
            "presence_penalty", "response_format", "seed", "stream", "top_p",
            "user", "function_call", "functions", "tools", "tool_choice"
        ]
        for param in supported_params:
            if param in kwargs:
                openai_params[param] = kwargs[param]
        return openai_params

    def _build_claude_params(self,
                             messages: List[Dict[str, str]],
                             temperature: float = 0.0,
                             max_tokens: int = None,
                             stop: List[str] = None,
                             **kwargs) -> Dict[str, Any]:
        claude_params = {
            "model": kwargs.get("model_name", self.model_name or ""),
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stop": stop
        }
        supported_params = [
            "top_p", "top_k", "reasoning_effort", "tools", "tool_choice"
        ]
        for param in supported_params:
            if param in kwargs:
                claude_params[param] = kwargs[param]
        return claude_params

    def _get_visit_info(self):
        visit_info = {
            "visitDomain": self.kwargs.get("ant_visit_domain") or os.getenv("ANT_VISIT_DOMAIN", "BU_general"),
            "visitBiz": self.kwargs.get("ant_visit_biz") or os.getenv("ANT_VISIT_BIZ", ""),
            "visitBizLine": self.kwargs.get("ant_visit_biz_line") or os.getenv("ANT_VISIT_BIZ_LINE", "")
        }
        if not visit_info["visitBiz"] or not visit_info["visitBizLine"]:
            return None
        return visit_info

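    # Environment expected by _get_visit_info (placeholder values, not real
    # business identifiers); both visitBiz and visitBizLine must be non-empty:
    #
    #   export ANT_VISIT_DOMAIN=BU_general
    #   export ANT_VISIT_BIZ=my_biz_code
    #   export ANT_VISIT_BIZ_LINE=my_biz_line
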
    def _get_service_param(self,
                           message_key: str,
                           output_type: str = "request",
                           messages: List[Dict[str, str]] = None,
                           temperature: float = 0.0,
                           max_tokens: int = None,
                           stop: List[str] = None,
                           **kwargs
                           ) -> Dict[str, Any]:
        """Build the service request parameters for the model and output type.

        Returns:
            Service parameter dict containing serviceName and queryConditions.
        """
        if messages:
            for message in messages:
                if message["role"] == "assistant" and "tool_calls" in message and message["tool_calls"]:
                    if message["content"] is None:
                        message["content"] = ""
                    processed_tool_calls = []
                    for tool_call in message["tool_calls"]:
                        if isinstance(tool_call, dict):
                            processed_tool_calls.append(tool_call)
                        elif isinstance(tool_call, ToolCall):
                            processed_tool_calls.append(tool_call.to_dict())
                    message["tool_calls"] = processed_tool_calls

        query_conditions = {
            "messageKey": message_key,
        }
        param = {"cacheInterval": -1}
        visit_info = self._get_visit_info()
        if not visit_info:
            raise LLMResponseError(
                "AntProvider#Invalid visit_info, please set ANT_VISIT_BIZ and ANT_VISIT_BIZ_LINE environment variables or provide them in the parameters",
                self.model_name or "unknown"
            )
        param.update(visit_info)

        if self.model_name.startswith("claude"):
            query_conditions.update(self._build_claude_params(messages, temperature, max_tokens, stop, **kwargs))
            param.update({
                "serviceName": "amazon_claude_chat_completions_dataview",
                "queryConditions": query_conditions,
            })
        elif output_type == "pull":
            param.update({
                "serviceName": "chatgpt_response_query_dataview",
                "queryConditions": query_conditions
            })
        else:
            query_conditions = {
                "model": self.model_name,
                "n": "1",
                "api_key": self.api_key,
                "messageKey": message_key,
                "outputType": "PULL",
                "messages": messages,
            }
            query_conditions.update(self._build_openai_params(messages, temperature, max_tokens, stop, **kwargs))
            param.update({
                "serviceName": "asyn_chatgpt_prompts_completions_query_dataview",
                "queryConditions": query_conditions,
            })
        return param

    def _gen_message_key(self):
        def _timestamp():
            return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")

        timestamp = _timestamp()
        message_key = "llm_call_%s" % timestamp
        return message_key

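    # Example of a generated key (format only; the value depends on the wall
    # clock): "llm_call_20250101_120000_123456". The key ties the initial
    # request to the later polling pulls.
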
    def _build_request_data(self, param: Dict[str, Any]):
        param_data = json.dumps(param)
        encrypted_param_data = self._aes_encrypt(param_data, self.aes_key)
        post_data = {"encryptedParam": encrypted_param_data}
        return post_data

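    # The wire format is a one-field JSON envelope (sketch; the value shown
    # is a description, not literal output):
    #
    #   {"encryptedParam": "<hex of AES-CBC(json.dumps(param), ANT_AES_KEY)>"}
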
    def _build_chat_query_request_data(self,
                                       message_key: str,
                                       messages: List[Dict[str, str]],
                                       temperature: float = 0.0,
                                       max_tokens: int = None,
                                       stop: List[str] = None,
                                       **kwargs):
        param = self._get_service_param(message_key, "request", messages, temperature, max_tokens, stop, **kwargs)
        query_data = self._build_request_data(param)
        return query_data

    def _post_chat_query_request(self,
                                 messages: List[Dict[str, str]],
                                 temperature: float = 0.0,
                                 max_tokens: int = None,
                                 stop: List[str] = None,
                                 **kwargs):
        message_key = self._gen_message_key()
        post_data = self._build_chat_query_request_data(message_key,
                                                        messages,
                                                        model_name=self.model_name,
                                                        temperature=temperature,
                                                        max_tokens=max_tokens,
                                                        stop=stop,
                                                        **kwargs)
        response = self.http_provider.sync_call(post_data, endpoint="commonQuery/queryData")
        return message_key, response

    def _valid_chat_result(self, body):
        if "data" not in body or not body["data"]:
            return False
        if "values" not in body["data"] or not body["data"]["values"]:
            return False
        if "response" not in body["data"]["values"] and "data" not in body["data"]["values"]:
            return False
        return True

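    # Shape of a body that passes _valid_chat_result (sketch; values are
    # illustrative). OpenAI-style results arrive under values["response"],
    # Claude results under values["data"]:
    #
    #   {"success": True, "data": {"values": {"response": "<escaped JSON>"}}}
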
    def _build_chat_pull_request_data(self, message_key):
        param = self._get_service_param(message_key, "pull")
        pull_data = self._build_request_data(param)
        return pull_data

    def _pull_chat_result(self, message_key, response: Dict[str, Any], timeout):
        if self.model_name.startswith("claude"):
            if self._valid_chat_result(response):
                # Decode the escaped payload: unquote via literal_eval,
                # unescape HTML entities, then parse the JSON
                x = response["data"]["values"]["data"]
                ast_str = ast.literal_eval("'" + x + "'")
                result = html.unescape(ast_str)
                data = json.loads(result)
                return data
            else:
                raise LLMResponseError(
                    f"Invalid response from Ant API, response: {response}",
                    self.model_name or "unknown"
                )

        post_data = self._build_chat_pull_request_data(message_key)
        url = 'commonQuery/queryData'
        headers = {
            'Content-Type': 'application/json'
        }

        # Poll until a valid result arrives or the timeout is reached
        start_time = time.time()
        elapsed_time = 0
        while elapsed_time < timeout:
            response = self.http_provider.sync_call(post_data, endpoint=url, headers=headers)
            logger.debug(f"Poll attempt at {elapsed_time}s, response: {response}")
            # Check if a valid result was received
            if self._valid_chat_result(response):
                x = response["data"]["values"]["response"]
                ast_str = ast.literal_eval("'" + x + "'")
                result = html.unescape(ast_str)
                data = json.loads(result)
                return data
            elif (not response.get("success")) or ("data" in response and response["data"]):
                err_code = response.get("data", {}).get("errorCode", "")
                err_msg = response.get("data", {}).get("errorMessage", "")
                if err_code or err_msg:
                    raise LLMResponseError(
                        f"Request failed: {response}",
                        self.model_name or "unknown"
                    )
            # No result yet: wait one second and query again
            time.sleep(1)
            elapsed_time = time.time() - start_time
            logger.debug(f"Polling... Elapsed time: {elapsed_time:.1f}s")

        # Timeout handling
        raise LLMResponseError(
            f"Timeout after {timeout} seconds waiting for response from Ant API",
            self.model_name or "unknown"
        )

    async def _async_pull_chat_result(self, message_key, response: Dict[str, Any], timeout):
        if self.model_name.startswith("claude"):
            if self._valid_chat_result(response):
                # Decode the escaped payload: unquote via literal_eval,
                # unescape HTML entities, then parse the JSON
                x = response["data"]["values"]["data"]
                ast_str = ast.literal_eval("'" + x + "'")
                result = html.unescape(ast_str)
                data = json.loads(result)
                return data
            elif (not response.get("success")) or ("data" in response and response["data"]):
                err_code = response.get("data", {}).get("errorCode", "")
                err_msg = response.get("data", {}).get("errorMessage", "")
                if err_code or err_msg:
                    raise LLMResponseError(
                        f"Request failed: {response}",
                        self.model_name or "unknown"
                    )

        post_data = self._build_chat_pull_request_data(message_key)
        url = 'commonQuery/queryData'
        headers = {
            'Content-Type': 'application/json'
        }

        # Poll until a valid result arrives or the timeout is reached
        start_time = time.time()
        elapsed_time = 0
        while elapsed_time < timeout:
            response = await self.http_provider.async_call(post_data, endpoint=url, headers=headers)
            logger.debug(f"Poll attempt at {elapsed_time}s, response: {response}")
            # Check if a valid result was received
            if self._valid_chat_result(response):
                x = response["data"]["values"]["response"]
                ast_str = ast.literal_eval("'" + x + "'")
                result = html.unescape(ast_str)
                data = json.loads(result)
                return data
            elif (not response.get("success")) or ("data" in response and response["data"]):
                err_code = response.get("data", {}).get("errorCode", "")
                err_msg = response.get("data", {}).get("errorMessage", "")
                if err_code or err_msg:
                    raise LLMResponseError(
                        f"Request failed: {response}",
                        self.model_name or "unknown"
                    )
            # No result yet: wait one second and query again
            await asyncio.sleep(1)
            elapsed_time = time.time() - start_time
            logger.debug(f"Polling... Elapsed time: {elapsed_time:.1f}s")

        # Timeout handling
        raise LLMResponseError(
            f"Timeout after {timeout} seconds waiting for response from Ant API",
            self.model_name or "unknown"
        )

    def _convert_completion_message(self, message: Dict[str, Any], is_finished: bool = False) -> ModelResponse:
        """Convert an Ant completion message to an OpenAI-style ModelResponse.

        Args:
            message: Ant completion message.
            is_finished: Whether this chunk ends the stream.

        Returns:
            ModelResponse in OpenAI format.
        """
        # Generate a unique ID
        response_id = f"ant-{hash(str(message)) & 0xffffffff:08x}"
        # Get content
        content = message.get("completion", "")
        # Create message object
        message_dict = {
            "role": "assistant",
            "content": content,
            "is_chunk": True
        }
        # Keep the original contextId and sessionId
        if "contextId" in message:
            message_dict["contextId"] = message["contextId"]
        if "sessionId" in message:
            message_dict["sessionId"] = message["sessionId"]
        usage = {
            "completion_tokens": message.get("completionToken", 0),
            "prompt_tokens": message.get("promptTokens", 0),
            "total_tokens": message.get("completionToken", 0) + message.get("promptTokens", 0)
        }
        # Process tool calls, accumulating streamed argument fragments per index
        tool_calls = message.get("toolCalls", [])
        for tool_call in tool_calls:
            index = tool_call.get("index", 0)
            name = tool_call.get("function", {}).get("name")
            arguments = tool_call.get("function", {}).get("arguments")
            if index >= len(self.stream_tool_buffer):
                self.stream_tool_buffer.append({
                    "id": tool_call.get("id"),
                    "type": "function",
                    "function": {
                        "name": name,
                        "arguments": arguments
                    }
                })
            else:
                # Guard against chunks that carry no argument fragment
                self.stream_tool_buffer[index]["function"]["arguments"] += arguments or ""
        if is_finished and self.stream_tool_buffer:
            message_dict["tool_calls"] = self.stream_tool_buffer.copy()
            processed_tool_calls = []
            for tool_call in self.stream_tool_buffer:
                processed_tool_calls.append(ToolCall.from_dict(tool_call))
            tool_resp = ModelResponse(
                id=response_id,
                model=self.model_name or "ant",
                content=content,
                tool_calls=processed_tool_calls,
                usage=usage,
                raw_response=message,
                message=message_dict
            )
            self.stream_tool_buffer = []
            return tool_resp
        # Build and return the ModelResponse object directly
        return ModelResponse(
            id=response_id,
            model=self.model_name or "ant",
            content=content,
            tool_calls=None,  # TODO: add tool calls
            usage=usage,
            raw_response=message,
            message=message_dict
        )

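    # Sketch of tool-call accumulation across stream chunks (illustrative
    # payloads). Argument fragments for the same index are concatenated, and
    # the buffer is flushed once on the finishing chunk:
    #
    #   chunk 1: {"toolCalls": [{"index": 0, "id": "c1",
    #                            "function": {"name": "f", "arguments": "{\"a\""}}]}
    #   chunk 2: {"toolCalls": [{"index": 0, "function": {"arguments": ": 1}"}}]}
    #   # finishing chunk -> one ToolCall with arguments '{"a": 1}'
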
    def preprocess_stream_call_message(self, messages: List[Dict[str, str]], ext_params: Dict[str, Any]) -> Dict[str, Any]:
        """Build the streaming request body from Ant-format messages.

        Args:
            messages: Ant format message list.
            ext_params: Extra parameters merged into the request body.

        Returns:
            Request parameter dict for the streaming endpoint.
        """
        param = {
            "messages": messages,
            "sessionId": "TkQUldjzOgYSKyTrpor3TA==",
            "model": self.model_name,
            "needMemory": False,
            "stream": True,
            "contextId": "contextId_34555fd2d246447fa55a1a259445a427",
            "platform": "AWorld"
        }
        for k in ext_params.keys():
            if k not in param:
                param[k] = ext_params[k]
        return param

    def postprocess_response(self, response: Any) -> ModelResponse:
        """Process an Ant response.

        Args:
            response: Ant response object.

        Returns:
            ModelResponse object.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """
        if ((not isinstance(response, dict) and (not hasattr(response, 'choices') or not response.choices))
                or (isinstance(response, dict) and not response.get("choices"))):
            error_msg = ""
            if hasattr(response, 'error') and response.error and isinstance(response.error, dict):
                error_msg = response.error.get('message', '')
            elif hasattr(response, 'msg'):
                error_msg = response.msg
            raise LLMResponseError(
                error_msg if error_msg else "Unknown error",
                self.model_name or "unknown",
                response
            )
        return ModelResponse.from_openai_response(response)

    def postprocess_stream_response(self, chunk: Any) -> ModelResponse:
        """Process an Ant stream response chunk.

        Args:
            chunk: Ant response chunk.

        Returns:
            ModelResponse object.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """
        # Check if the chunk contains an error
        if hasattr(chunk, 'error') or (isinstance(chunk, dict) and chunk.get('error')):
            error_msg = chunk.error if hasattr(chunk, 'error') else chunk.get('error', 'Unknown error')
            raise LLMResponseError(
                error_msg,
                self.model_name or "unknown",
                chunk
            )
        if isinstance(chunk, dict) and ('completion' in chunk):
            return self._convert_completion_message(chunk)
        # If the chunk is already in OpenAI format, use the standard processing path
        return ModelResponse.from_openai_stream_chunk(chunk)

    def completion(self,
                   messages: List[Dict[str, str]],
                   temperature: float = 0.0,
                   max_tokens: int = None,
                   stop: List[str] = None,
                   **kwargs) -> ModelResponse:
        """Synchronously call Ant to generate a response.

        Args:
            messages: Message list.
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            ModelResponse object.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """
        if not self.provider:
            raise RuntimeError(
                "Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")

        try:
            start_time = time.time()
            message_key, response = self._post_chat_query_request(messages, temperature, max_tokens, stop, **kwargs)
            timeout = kwargs.get("response_timeout", self.kwargs.get("timeout", 180))
            result = self._pull_chat_result(message_key, response, timeout)
            logger.info(f"completion cost time: {time.time() - start_time}s.")
            resp = self.postprocess_response(result)
            usage_process(resp.usage)
            return resp
        except Exception as e:
            if isinstance(e, LLMResponseError):
                raise e
            logger.warn(f"Error in Ant completion: {e}")
            raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))

    async def acompletion(self,
                          messages: List[Dict[str, str]],
                          temperature: float = 0.0,
                          max_tokens: int = None,
                          stop: List[str] = None,
                          **kwargs) -> ModelResponse:
        """Asynchronously call Ant to generate a response.

        Args:
            messages: Message list.
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            ModelResponse object.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """
        if not self.async_provider:
            self._init_async_provider()

        start_time = time.time()
        try:
            message_key, response = self._post_chat_query_request(messages, temperature, max_tokens, stop, **kwargs)
            timeout = kwargs.get("response_timeout", self.kwargs.get("timeout", 180))
            result = await self._async_pull_chat_result(message_key, response, timeout)
            logger.info(f"acompletion cost time: {time.time() - start_time}s.")
            resp = self.postprocess_response(result)
            usage_process(resp.usage)
            return resp
        except Exception as e:
            if isinstance(e, LLMResponseError):
                raise e
            logger.warn(f"Error in async Ant completion: {e}")
            raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))

    def stream_completion(self,
                          messages: List[Dict[str, str]],
                          temperature: float = 0.0,
                          max_tokens: int = None,
                          stop: List[str] = None,
                          **kwargs) -> Generator[ModelResponse, None, None]:
        """Synchronously call Ant to generate a streaming response.

        Args:
            messages: Message list.
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            Generator yielding ModelResponse chunks.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """
        if not self.provider:
            raise RuntimeError(
                "Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")

        start_time = time.time()
        # Generate message_key
        timestamp = int(time.time())
        self.message_key = f"llm_call_{timestamp}"
        self.aes_key = kwargs.get("aes_key", self.aes_key)
        # Add streaming parameter
        kwargs["stream"] = True
        processed_messages = self.preprocess_stream_call_message(
            messages, self._build_openai_params(messages, temperature, max_tokens, stop, **kwargs))
        if not processed_messages:
            raise LLMResponseError("Failed to get post data", self.model_name or "unknown")

        usage = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0
        }
        try:
            # Send request
            headers = {
                "Content-Type": "application/json",
                "X_ACCESS_KEY": self.stream_api_key
            }
            response_stream = self.http_provider.sync_stream_call(processed_messages, endpoint="chat/completions",
                                                                  headers=headers)
            if response_stream:
                for chunk in response_stream:
                    if not chunk:
                        continue
                    # Process special markers
                    if isinstance(chunk, dict) and "status" in chunk:
                        if chunk["status"] == "done":
                            # Stream completion marker: flush buffered tool calls, then end
                            logger.info("Received [DONE] marker, stream completed")
                            yield self._convert_completion_message(chunk, is_finished=True)
                            yield ModelResponse.from_special_marker("done", self.model_name, chunk)
                            break
                        elif chunk["status"] == "revoke":
                            # Revoke marker: the frontend should revoke the displayed content
                            logger.info("Received [REVOKE] marker, content should be revoked")
                            yield ModelResponse.from_special_marker("revoke", self.model_name, chunk)
                        elif chunk["status"] == "fail":
                            # Fail marker
                            logger.error("Received [FAIL] marker, request failed")
                            raise LLMResponseError("Request failed", self.model_name or "unknown")
                        elif chunk["status"] == "cancel":
                            # Request was cancelled
                            logger.warning("Received [CANCEL] marker, stream was cancelled")
                            raise LLMResponseError("Stream was cancelled", self.model_name or "unknown")
                        # Skip any other status markers
                        continue
                    # Process normal response chunks
                    resp = self.postprocess_stream_response(chunk)
                    self._accumulate_chunk_usage(usage, resp.usage)
                    yield resp
            usage_process(usage)
            logger.info(f"stream_completion cost time: {time.time() - start_time}s.")
        except Exception as e:
            if isinstance(e, LLMResponseError):
                raise e
            logger.error(f"Error in Ant stream completion: {e}")
            raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))

    async def astream_completion(self,
                                 messages: List[Dict[str, str]],
                                 temperature: float = 0.0,
                                 max_tokens: int = None,
                                 stop: List[str] = None,
                                 **kwargs) -> AsyncGenerator[ModelResponse, None]:
        """Asynchronously call Ant to generate a streaming response.

        Args:
            messages: Message list.
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            AsyncGenerator yielding ModelResponse chunks.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """
        if not self.async_provider:
            self._init_async_provider()

        start_time = time.time()
        # Generate message_key
        timestamp = int(time.time())
        self.message_key = f"llm_call_{timestamp}"
        self.aes_key = kwargs.get("aes_key", self.aes_key)
        # Add streaming parameter
        kwargs["stream"] = True
        processed_messages = self.preprocess_stream_call_message(
            messages, self._build_openai_params(messages, temperature, max_tokens, stop, **kwargs))
        if not processed_messages:
            raise LLMResponseError("Failed to get post data", self.model_name or "unknown")

        usage = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0
        }
        try:
            headers = {
                "Content-Type": "application/json",
                "X_ACCESS_KEY": self.stream_api_key
            }
            logger.info(f"astream_completion request data: {processed_messages}")
            async for chunk in self.http_provider.async_stream_call(processed_messages, endpoint="chat/completions",
                                                                    headers=headers):
                if not chunk:
                    continue
                # Process special markers
                if isinstance(chunk, dict) and "status" in chunk:
                    if chunk["status"] == "done":
                        # Stream completion marker
                        logger.info("Received [DONE] marker, stream completed")
                        yield ModelResponse.from_special_marker("done", self.model_name, chunk)
                        break
                    elif chunk["status"] == "revoke":
                        # Revoke marker: the frontend should revoke the displayed content
                        logger.info("Received [REVOKE] marker, content should be revoked")
                        yield ModelResponse.from_special_marker("revoke", self.model_name, chunk)
                    elif chunk["status"] == "fail":
                        # Fail marker
                        logger.error("Received [FAIL] marker, request failed")
                        raise LLMResponseError("Request failed", self.model_name or "unknown")
                    elif chunk["status"] == "cancel":
                        # Request was cancelled
                        logger.warning("Received [CANCEL] marker, stream was cancelled")
                        raise LLMResponseError("Stream was cancelled", self.model_name or "unknown")
                    # Skip any other status markers
                    continue
                # Process normal response chunks
                resp = self.postprocess_stream_response(chunk)
                self._accumulate_chunk_usage(usage, resp.usage)
                yield resp
            usage_process(usage)
            logger.info(f"astream_completion cost time: {time.time() - start_time}s.")
        except Exception as e:
            if isinstance(e, LLMResponseError):
                raise e
            logger.warn(f"Error in async Ant stream completion: {e}")
            raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))

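# A minimal end-to-end sketch (assumed constructor arguments; see
# LLMProviderBase for the actual initialization contract):
#
#   provider = AntProvider(api_key="...", model_name="gpt-4o")
#   resp = provider.completion([{"role": "user", "content": "hello"}])
#   print(resp.content)
#
#   # Streaming variant:
#   for chunk in provider.stream_completion([{"role": "user", "content": "hi"}]):
#       print(chunk.content, end="")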