import os

from typing import Any, Dict, List, Optional, Generator, AsyncGenerator

from openai import OpenAI, AsyncOpenAI

from aworld.config.conf import ClientType
from aworld.core.llm_provider_base import LLMProviderBase
from aworld.models.llm_http_handler import LLMHTTPHandler
from aworld.models.model_response import ModelResponse, LLMResponseError
from aworld.logs.util import logger
from aworld.models.utils import usage_process

class OpenAIProvider(LLMProviderBase):
    """OpenAI provider implementation."""

    def _init_provider(self):
        """Initialize the OpenAI provider.

        Returns:
            OpenAI client or HTTP handler instance.
        """
        api_key = self.api_key
        if not api_key:
            env_var = "OPENAI_API_KEY"
            api_key = os.getenv(env_var, "")
            if not api_key:
                raise ValueError(
                    f"OpenAI API key not found, please set the {env_var} environment variable or provide it in the parameters")

        base_url = self.base_url
        if not base_url:
            base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")

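        # Choose the transport: a raw HTTP handler when client_type is HTTP,
        # otherwise the official OpenAI SDK client.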
        self.is_http_provider = False
        if self.kwargs.get("client_type", ClientType.SDK) == ClientType.HTTP:
            logger.info("Using HTTP provider for OpenAI")
            self.http_provider = LLMHTTPHandler(
                base_url=base_url,
                api_key=api_key,
                model_name=self.model_name,
                max_retries=self.kwargs.get("max_retries", 3)
            )
            self.is_http_provider = True
            return self.http_provider
        else:
            return OpenAI(
                api_key=api_key,
                base_url=base_url,
                timeout=self.kwargs.get("timeout", 180),
                max_retries=self.kwargs.get("max_retries", 3)
            )

    def _init_async_provider(self):
        """Initialize the async OpenAI provider.

        Returns:
            Async OpenAI client instance.
        """
        api_key = self.api_key
        if not api_key:
            env_var = "OPENAI_API_KEY"
            api_key = os.getenv(env_var, "")
            if not api_key:
                raise ValueError(
                    f"OpenAI API key not found, please set the {env_var} environment variable or provide it in the parameters")

        base_url = self.base_url
        if not base_url:
            base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")

        return AsyncOpenAI(
            api_key=api_key,
            base_url=base_url,
            timeout=self.kwargs.get("timeout", 180),
            max_retries=self.kwargs.get("max_retries", 3)
        )

    @classmethod
    def supported_models(cls) -> List[str]:
        return ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini", "gpt-4o-mini", "deepseek-chat", "deepseek-reasoner",
                r"qwq-.*", r"qwen-.*"]

    def preprocess_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Preprocess messages, which are already in OpenAI format.

        Normalizes assistant tool calls: ensures content is a string and wraps
        flat {"name", "arguments"} tool calls into the nested "function" form.

        Args:
            messages: OpenAI format message list.

        Returns:
            Processed message list.
        """
        for message in messages:
            if message["role"] == "assistant" and "tool_calls" in message and message["tool_calls"]:
                if message["content"] is None:
                    message["content"] = ""
                for tool_call in message["tool_calls"]:
                    if "function" not in tool_call and "name" in tool_call and "arguments" in tool_call:
                        tool_call["function"] = {"name": tool_call["name"], "arguments": tool_call["arguments"]}

        return messages

    def postprocess_response(self, response: Any) -> ModelResponse:
        """Process an OpenAI response.

        Args:
            response: OpenAI response object.

        Returns:
            ModelResponse object.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """
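        # A well-formed response must carry a non-empty "choices" list; anything
        # else is treated as an error payload (SDK object or plain dict).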
        if ((not isinstance(response, dict) and (not hasattr(response, 'choices') or not response.choices))
                or (isinstance(response, dict) and not response.get("choices"))):
            error_msg = ""
            if isinstance(response, dict):
                error = response.get("error")
                error_msg = error.get("message", "") if isinstance(error, dict) else response.get("msg", "")
            elif hasattr(response, 'error') and response.error and isinstance(response.error, dict):
                error_msg = response.error.get('message', '')
            elif hasattr(response, 'msg'):
                error_msg = response.msg

            raise LLMResponseError(
                error_msg if error_msg else "Unknown error",
                self.model_name or "unknown",
                response
            )

        return ModelResponse.from_openai_response(response)

    def postprocess_stream_response(self, chunk: Any) -> Optional[ModelResponse]:
        """Process an OpenAI streaming response chunk.

        Args:
            chunk: OpenAI response chunk.

        Returns:
            ModelResponse object, or None for chunks that carry neither content nor usage.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """
        if hasattr(chunk, 'error') or (isinstance(chunk, dict) and chunk.get('error')):
            error_msg = chunk.error if hasattr(chunk, 'error') else chunk.get('error', 'Unknown error')
            raise LLMResponseError(
                error_msg,
                self.model_name or "unknown",
                chunk
            )

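        # Streaming tool calls arrive as incremental deltas spread over several
        # chunks; buffer them by index in self.stream_tool_buffer until the
        # stream finishes.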
        if (hasattr(chunk, 'choices') and chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.tool_calls) or (
                isinstance(chunk, dict) and chunk.get("choices") and chunk["choices"][0].get("delta", {}).get("tool_calls")):
            tool_calls = chunk.choices[0].delta.tool_calls if hasattr(chunk, 'choices') else chunk["choices"][0].get("delta", {}).get("tool_calls")

            for tool_call in tool_calls:
                index = tool_call.index if hasattr(tool_call, 'index') else tool_call["index"]
                func_name = tool_call.function.name if hasattr(tool_call, 'function') else tool_call.get("function", {}).get("name")
                func_args = tool_call.function.arguments if hasattr(tool_call, 'function') else tool_call.get("function", {}).get("arguments")
                if index >= len(self.stream_tool_buffer):
                    self.stream_tool_buffer.append({
                        "id": tool_call.id if hasattr(tool_call, 'id') else tool_call.get("id"),
                        "type": "function",
                        "function": {
                            "name": func_name,
                            "arguments": func_args or ""
                        }
                    })
                else:
                    self.stream_tool_buffer[index]["function"]["arguments"] += func_args or ""
            processed_chunk = chunk
            if hasattr(processed_chunk, 'choices'):
                processed_chunk.choices[0].delta.tool_calls = None
            else:
                processed_chunk["choices"][0]["delta"]["tool_calls"] = None
            resp = ModelResponse.from_openai_stream_chunk(processed_chunk)
            if not resp.content and not resp.usage.get("total_tokens", 0):
                return None

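        # On the final chunk (finish_reason set), flush the buffered tool calls
        # as one synthetic chunk so callers receive complete tool invocations.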
        if (hasattr(chunk, 'choices') and chunk.choices and chunk.choices[0].finish_reason) or (
                isinstance(chunk, dict) and chunk.get("choices") and chunk["choices"][0].get("finish_reason")):
            finish_reason = chunk.choices[0].finish_reason if hasattr(chunk, 'choices') else chunk["choices"][0].get("finish_reason")
            if self.stream_tool_buffer:
                tool_call_chunk = {
                    "id": chunk.id if hasattr(chunk, 'id') else chunk.get("id"),
                    "model": chunk.model if hasattr(chunk, 'model') else chunk.get("model"),
                    "object": chunk.object if hasattr(chunk, 'object') else chunk.get("object"),
                    "choices": [
                        {
                            "delta": {
                                "role": "assistant",
                                "content": "",
                                "tool_calls": self.stream_tool_buffer
                            },
                            "finish_reason": finish_reason
                        }
                    ]
                }
                self.stream_tool_buffer = []
                return ModelResponse.from_openai_stream_chunk(tool_call_chunk)

        return ModelResponse.from_openai_stream_chunk(chunk)

    def completion(self,
                   messages: List[Dict[str, str]],
                   temperature: float = 0.0,
                   max_tokens: Optional[int] = None,
                   stop: Optional[List[str]] = None,
                   **kwargs) -> ModelResponse:
        """Synchronously call OpenAI to generate a response.

        Args:
            messages: Message list.
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            ModelResponse object.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """
        if not self.provider:
            raise RuntimeError(
                "Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")

        processed_messages = self.preprocess_messages(messages)

        try:
            openai_params = self.get_openai_params(processed_messages, temperature, max_tokens, stop, **kwargs)
            if self.is_http_provider:
                response = self.http_provider.sync_call(openai_params)
            else:
                response = self.provider.chat.completions.create(**openai_params)

            if (hasattr(response, 'code') and response.code != 0) or (
                    isinstance(response, dict) and response.get("code", 0) != 0):
                error_msg = response.get("msg", "Unknown error") if isinstance(response, dict) else getattr(response, 'msg', 'Unknown error')
                logger.warn(f"API Error: {error_msg}")
                raise LLMResponseError(error_msg, kwargs.get("model_name", self.model_name or "unknown"), response)

            if not response:
                raise LLMResponseError("Empty response", kwargs.get("model_name", self.model_name or "unknown"))

            resp = self.postprocess_response(response)
            usage_process(resp.usage)
            return resp
        except LLMResponseError:
            raise
        except Exception as e:
            logger.warn(f"Error in OpenAI completion: {e}")
            raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown")) from e

    def stream_completion(self,
                          messages: List[Dict[str, str]],
                          temperature: float = 0.0,
                          max_tokens: Optional[int] = None,
                          stop: Optional[List[str]] = None,
                          **kwargs) -> Generator[ModelResponse, None, None]:
        """Synchronously call OpenAI to generate a streaming response.

        Args:
            messages: Message list.
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            Generator yielding ModelResponse chunks.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """
        if not self.provider:
            raise RuntimeError(
                "Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")

        processed_messages = self.preprocess_messages(messages)
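        # Token usage is accumulated across chunks and reported once at stream end.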
        usage = {
            "completion_tokens": 0,
            "prompt_tokens": 0,
            "total_tokens": 0
        }

        try:
            openai_params = self.get_openai_params(processed_messages, temperature, max_tokens, stop, **kwargs)
            openai_params["stream"] = True
            if self.is_http_provider:
                response_stream = self.http_provider.sync_stream_call(openai_params)
            else:
                response_stream = self.provider.chat.completions.create(**openai_params)

            for chunk in response_stream:
                if not chunk:
                    continue
                resp = self.postprocess_stream_response(chunk)
                if resp:
                    self._accumulate_chunk_usage(usage, resp.usage)
                    yield resp
            usage_process(usage)
        except LLMResponseError:
            raise
        except Exception as e:
            logger.warn(f"Error in stream_completion: {e}")
            raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown")) from e

    async def astream_completion(self,
                                 messages: List[Dict[str, str]],
                                 temperature: float = 0.0,
                                 max_tokens: Optional[int] = None,
                                 stop: Optional[List[str]] = None,
                                 **kwargs) -> AsyncGenerator[ModelResponse, None]:
        """Asynchronously call OpenAI to generate a streaming response.

        Args:
            messages: Message list.
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            AsyncGenerator yielding ModelResponse chunks.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """
        if not self.async_provider:
            raise RuntimeError(
                "Async provider not initialized. Make sure 'async_enabled' parameter is set to True in initialization.")

        processed_messages = self.preprocess_messages(messages)
        usage = {
            "completion_tokens": 0,
            "prompt_tokens": 0,
            "total_tokens": 0
        }

        try:
            openai_params = self.get_openai_params(processed_messages, temperature, max_tokens, stop, **kwargs)
            openai_params["stream"] = True
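            # Both transports yield chunks through postprocess_stream_response;
            # empty chunks and content-less deltas are skipped.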
            if self.is_http_provider:
                async for chunk in self.http_provider.async_stream_call(openai_params):
                    if not chunk:
                        continue
                    resp = self.postprocess_stream_response(chunk)
                    if resp:
                        self._accumulate_chunk_usage(usage, resp.usage)
                        yield resp
            else:
                response_stream = await self.async_provider.chat.completions.create(**openai_params)
                async for chunk in response_stream:
                    if not chunk:
                        continue
                    resp = self.postprocess_stream_response(chunk)
                    if resp:
                        self._accumulate_chunk_usage(usage, resp.usage)
                        yield resp
            usage_process(usage)
        except LLMResponseError:
            raise
        except Exception as e:
            logger.warn(f"Error in astream_completion: {e}")
            raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown")) from e

    async def acompletion(self,
                          messages: List[Dict[str, str]],
                          temperature: float = 0.0,
                          max_tokens: Optional[int] = None,
                          stop: Optional[List[str]] = None,
                          **kwargs) -> ModelResponse:
        """Asynchronously call OpenAI to generate a response.

        Args:
            messages: Message list.
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            ModelResponse object.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """
        if not self.async_provider:
            raise RuntimeError(
                "Async provider not initialized. Make sure 'async_enabled' parameter is set to True in initialization.")

        processed_messages = self.preprocess_messages(messages)

        try:
            openai_params = self.get_openai_params(processed_messages, temperature, max_tokens, stop, **kwargs)
            if self.is_http_provider:
                response = await self.http_provider.async_call(openai_params)
            else:
                response = await self.async_provider.chat.completions.create(**openai_params)

            if (hasattr(response, 'code') and response.code != 0) or (
                    isinstance(response, dict) and response.get("code", 0) != 0):
                error_msg = response.get("msg", "Unknown error") if isinstance(response, dict) else getattr(response, 'msg', 'Unknown error')
                logger.warn(f"API Error: {error_msg}")
                raise LLMResponseError(error_msg, kwargs.get("model_name", self.model_name or "unknown"), response)

            if not response:
                raise LLMResponseError("Empty response", kwargs.get("model_name", self.model_name or "unknown"))

            resp = self.postprocess_response(response)
            usage_process(resp.usage)
            return resp
        except LLMResponseError:
            raise
        except Exception as e:
            logger.warn(f"Error in acompletion: {e}")
            raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown")) from e

    def get_openai_params(self,
                          messages: List[Dict[str, str]],
                          temperature: float = 0.0,
                          max_tokens: Optional[int] = None,
                          stop: Optional[List[str]] = None,
                          **kwargs) -> Dict[str, Any]:
        openai_params = {
            "model": kwargs.get("model_name", self.model_name or ""),
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stop": stop
        }

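        # Optional OpenAI parameters forwarded verbatim when present in kwargs.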
        supported_params = [
            "max_completion_tokens", "meta_data", "modalities", "n", "parallel_tool_calls",
            "prediction", "reasoning_effort", "service_tier", "stream_options", "web_search_options",
            "frequency_penalty", "logit_bias", "logprobs", "top_logprobs",
            "presence_penalty", "response_format", "seed", "stream", "top_p",
            "user", "function_call", "functions", "tools", "tool_choice"
        ]

        for param in supported_params:
            if param in kwargs:
                openai_params[param] = kwargs[param]

        return openai_params

    def speech_to_text(self,
                       audio_file: Any,
                       language: Optional[str] = None,
                       prompt: Optional[str] = None,
                       **kwargs) -> ModelResponse:
        """Convert speech to text.

        Uses OpenAI's speech-to-text API to convert audio files to text.

        Args:
            audio_file: Path to an audio file, or a file object.
            language: Audio language, optional.
            prompt: Transcription prompt, optional.
            **kwargs: Other parameters, may include:
                - model: Transcription model name, defaults to "whisper-1".
                - response_format: Response format, defaults to "text".
                - temperature: Sampling temperature, defaults to 0.

        Returns:
            ModelResponse: Unified model response object, with the content field containing the transcription result.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """
        if not self.provider:
            raise RuntimeError(
                "Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")

        try:
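            # Base transcription request; optional language/prompt fields are
            # added below only when provided.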
            transcription_params = {
                "model": kwargs.get("model", "whisper-1"),
                "response_format": kwargs.get("response_format", "text"),
                "temperature": kwargs.get("temperature", 0)
            }

            if language:
                transcription_params["language"] = language
            if prompt:
                transcription_params["prompt"] = prompt

            if isinstance(audio_file, str):
                with open(audio_file, "rb") as file:
                    transcription_response = self.provider.audio.transcriptions.create(
                        file=file,
                        **transcription_params
                    )
            else:
                transcription_response = self.provider.audio.transcriptions.create(
                    file=audio_file,
                    **transcription_params
                )

            content = transcription_response.text if hasattr(transcription_response, 'text') else str(transcription_response)
            return ModelResponse(
                id=f"stt-{hash(str(transcription_response)) & 0xffffffff:08x}",
                model=transcription_params["model"],
                content=content,
                raw_response=transcription_response,
                message={
                    "role": "assistant",
                    "content": content
                }
            )
        except Exception as e:
            logger.warn(f"Speech-to-text error: {e}")
            raise LLMResponseError(str(e), kwargs.get("model", "whisper-1")) from e

    async def aspeech_to_text(self,
                              audio_file: Any,
                              language: Optional[str] = None,
                              prompt: Optional[str] = None,
                              **kwargs) -> ModelResponse:
        """Asynchronously convert speech to text.

        Uses OpenAI's speech-to-text API to convert audio files to text.

        Args:
            audio_file: Path to an audio file, or a file object.
            language: Audio language, optional.
            prompt: Transcription prompt, optional.
            **kwargs: Other parameters, may include:
                - model: Transcription model name, defaults to "whisper-1".
                - response_format: Response format, defaults to "text".
                - temperature: Sampling temperature, defaults to 0.

        Returns:
            ModelResponse: Unified model response object, with the content field containing the transcription result.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """
        if not self.async_provider:
            raise RuntimeError(
                "Async provider not initialized. Make sure 'async_enabled' parameter is set to True in initialization.")

        try:
            transcription_params = {
                "model": kwargs.get("model", "whisper-1"),
                "response_format": kwargs.get("response_format", "text"),
                "temperature": kwargs.get("temperature", 0)
            }

            if language:
                transcription_params["language"] = language
            if prompt:
                transcription_params["prompt"] = prompt

            if isinstance(audio_file, str):
                with open(audio_file, "rb") as file:
                    transcription_response = await self.async_provider.audio.transcriptions.create(
                        file=file,
                        **transcription_params
                    )
            else:
                transcription_response = await self.async_provider.audio.transcriptions.create(
                    file=audio_file,
                    **transcription_params
                )

            content = transcription_response.text if hasattr(transcription_response, 'text') else str(transcription_response)
            return ModelResponse(
                id=f"stt-{hash(str(transcription_response)) & 0xffffffff:08x}",
                model=transcription_params["model"],
                content=content,
                raw_response=transcription_response,
                message={
                    "role": "assistant",
                    "content": content
                }
            )
        except Exception as e:
            logger.warn(f"Async speech-to-text error: {e}")
            raise LLMResponseError(str(e), kwargs.get("model", "whisper-1")) from e

class AzureOpenAIProvider(OpenAIProvider):
    """Azure OpenAI provider implementation."""

    def _init_provider(self):
        """Initialize the Azure OpenAI provider.

        Returns:
            Azure OpenAI provider instance.
        """
        from langchain_openai import AzureChatOpenAI

        api_key = self.api_key
        if not api_key:
            env_var = "AZURE_OPENAI_API_KEY"
            api_key = os.getenv(env_var, "")
            if not api_key:
                raise ValueError(
                    f"Azure OpenAI API key not found, please set the {env_var} environment variable or provide it in the parameters")

        api_version = self.kwargs.get("api_version", "") or os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview")

        azure_endpoint = self.base_url
        if not azure_endpoint:
            azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT", "")
            if not azure_endpoint:
                raise ValueError(
                    "Azure OpenAI endpoint not found, please set the AZURE_OPENAI_ENDPOINT environment variable or provide it in the parameters")

        return AzureChatOpenAI(
            model=self.model_name or "gpt-4o",
            temperature=self.kwargs.get("temperature", 0.0),
            api_version=api_version,
            azure_endpoint=azure_endpoint,
            api_key=api_key
        )
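
# Minimal usage sketch (not part of the provider implementation). It assumes
# LLMProviderBase accepts api_key/base_url/model_name keyword arguments and
# enables the sync client by default; adjust to the actual constructor.
if __name__ == "__main__":
    provider = OpenAIProvider(
        api_key=os.getenv("OPENAI_API_KEY"),
        base_url="https://api.openai.com/v1",
        model_name="gpt-4o-mini",
    )
    result = provider.completion(
        messages=[{"role": "user", "content": "Say hello in one word."}],
        temperature=0.0,
    )
    print(result.content)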