# AILanguageCompanion / custom_openai_client.py
# Last commit by koura718: "Change custom_openai_client and deepseek_client" (d0d1766)
from typing import List, Dict, Optional, Union
from openai import OpenAI
from openai.types.chat import ChatCompletion
from openai._types import NotGiven, NOT_GIVEN
import openai
import requests
from dataclasses import dataclass
from enum import Enum
class ModelProvider(Enum):
    """Chat-completion backends this module knows how to talk to.

    Member values are lowercase provider identifier strings, so a member can
    be looked up from configuration text, e.g. ``ModelProvider("openai")``.
    """

    # OpenAI chat-completions backend.
    OPENAI = "openai"
    # DeepSeek chat-completions backend.
    DEEPSEEK = "deepseek"
@dataclass
class ApiConfig:
    """Connection settings for one chat-completions provider.

    Attributes:
        base_url: Scheme and host of the provider, e.g.
            ``"https://api.openai.com"``. A trailing ``/`` is tolerated.
        api_version: Version path segment placed before the endpoint path
            (``"v1"`` by default). Surrounding slashes are ignored; an empty
            string omits the version segment entirely.
    """

    base_url: str
    api_version: str = "v1"

    @property
    def chat_endpoint(self) -> str:
        """Return the full URL of the provider's chat-completions endpoint."""
        # Normalize slashes so "https://host/" or "/v1/" never yield "//" in
        # the URL, and skip the version segment when it is empty.
        parts = [self.base_url.rstrip("/")]
        version = self.api_version.strip("/")
        if version:
            parts.append(version)
        parts.append("chat/completions")
        return "/".join(parts)
class CustomOpenAI(OpenAI):
    """``openai.OpenAI`` client extended with simplified chat helpers.

    Adds three convenience methods on top of the SDK client:

    * ``simple_chat`` -- chat completion through the inherited SDK call
      ``self.chat.completions.create``, returning only the text of the first
      choice.
    * ``openai_chat`` -- the same kind of request sent as a raw HTTP POST with
      ``requests`` to the OpenAI endpoint listed in ``API_CONFIGS``.
    * ``chat_with_retry`` -- selects a chat method by ``ModelProvider`` and
      calls it up to ``max_retries`` times until it returns a result.

    Failures inside these helpers are reported with ``print`` and turned into
    a ``None`` return value rather than propagated.
    """

    # API configuration (base URL + version) for each supported provider.
    API_CONFIGS = {
        ModelProvider.OPENAI: ApiConfig("https://api.openai.com"),
        ModelProvider.DEEPSEEK: ApiConfig("https://api.deepseek.com"),
    }

    # Whether raw ``requests`` calls verify TLS certificates. Declared at class
    # level so instances created without ``__init__`` still verify by default.
    verify_ssl: bool = True

    def __init__(
        self,
        api_key: Optional[str] = None,
        organization: Optional[str] = None,
        deepseek_api_key: Optional[str] = None,
        *,
        verify_ssl: bool = True,
        **kwargs,
    ):
        """
        Initialize CustomOpenAI client with enhanced chat functionality.

        Args:
            api_key: OpenAI API key, forwarded to ``openai.OpenAI``.
            organization: Organization ID (optional), forwarded to
                ``openai.OpenAI``.
            deepseek_api_key: DeepSeek API key (optional); stored on the
                instance as ``self.deepseek_api_key``.
            verify_ssl: Whether ``openai_chat`` verifies the server's TLS
                certificate (keyword-only, default ``True``). Disable only in a
                development environment; with verification off, the API key
                can be intercepted.
            **kwargs: Additional client configuration forwarded to
                ``openai.OpenAI``.
        """
        super().__init__(
            api_key=api_key,
            organization=organization,
            **kwargs,
        )
        self.deepseek_api_key = deepseek_api_key
        self.verify_ssl = verify_ssl

    def simple_chat(
        self,
        messages: List[Dict[str, str]],
        model: str = "gpt-3.5-turbo",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> Optional[str]:
        """
        Simplified chat completion method that returns just the message content.

        Uses the inherited SDK call ``self.chat.completions.create``.

        Args:
            messages: List of message dictionaries with 'role' and 'content'.
            model: Model identifier to use.
            temperature: Sampling temperature (0-2).
            max_tokens: Maximum number of tokens to generate; omitted from the
                request when ``None``.
            **kwargs: Additional parameters passed to the API; they override
                the fields above on key collision.

        Returns:
            Content of the first choice's message, or ``None`` if the response
            has no choices or any error occurs (the error is printed).
        """
        try:
            params = {
                "model": model,
                "messages": messages,
                "temperature": temperature,
            }
            # Only send max_tokens when the caller asked for a limit.
            if max_tokens is not None:
                params["max_tokens"] = max_tokens
            params.update(kwargs)

            response: ChatCompletion = self.chat.completions.create(**params)

            if response.choices:
                return response.choices[0].message.content
            return None

        except openai.APIError as e:
            print(f"OpenAI API Error: {str(e)}")
            return None
        except Exception as e:
            print(f"Unexpected error: {str(e)}")
            return None

    def openai_chat(
        self,
        messages: List[Dict[str, str]],
        # NOTE(review): "gpt-4-mini" does not look like a published OpenAI
        # model ID (perhaps "gpt-4o-mini" was meant) -- confirm before relying
        # on this default.
        model: str = "gpt-4-mini",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> Optional[str]:
        """
        Chat completion for OpenAI models via a raw HTTP request.

        Sends a non-streaming POST to the OpenAI chat-completions endpoint from
        ``API_CONFIGS`` using ``requests``.

        Args:
            messages: List of message dictionaries with 'role' and 'content'.
            model: OpenAI model identifier.
            temperature: Sampling temperature (0-2).
            max_tokens: Maximum number of tokens to generate; omitted from the
                request body when ``None``.
            **kwargs: Additional fields merged into the JSON request body; they
                override the fields above on key collision.

        Returns:
            Content of the first choice's message, or ``None`` on a network
            error, non-200 status, missing choices, or any other failure (the
            error is printed).

        Raises:
            ValueError: If the client has no API key.
        """
        if not self.api_key:
            raise ValueError("OpenAI API key is required for openai_chat")

        try:
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}",
            }
            # Send the organization header only when one is configured, rather
            # than sending it with an empty value.
            if self.organization:
                headers["OpenAI-Organization"] = self.organization

            data = {
                "model": model,
                "messages": messages,
                "temperature": temperature,
                "stream": False,
            }
            if max_tokens is not None:
                data["max_tokens"] = max_tokens
            data.update(kwargs)

            config = self.API_CONFIGS[ModelProvider.OPENAI]
            response = requests.post(
                config.chat_endpoint,
                headers=headers,
                json=data,
                # Certificate verification protects the bearer token in transit;
                # it is only disabled when the client was built with
                # verify_ssl=False (development use).
                verify=self.verify_ssl,
                # 10 s connect timeout, 60 s read timeout.
                timeout=(10, 60),
            )

            if response.status_code != 200:
                error_msg = f"OpenAI API request failed with status {response.status_code}"
                try:
                    error_data = response.json()
                    print(f"[DEBUG] Error response: {error_data}")
                    if "error" in error_data:
                        error_msg += f": {error_data['error']}"
                except Exception as e:
                    print(f"[DEBUG] Failed to parse error response: {str(e)}")
                # Raised here and caught below so every failure path ends in
                # the same "print and return None" handling.
                raise ValueError(error_msg)

            response_data = response.json()
            if not response_data.get("choices"):
                raise ValueError("No choices in OpenAI API response")
            return response_data["choices"][0]["message"]["content"]

        except requests.exceptions.RequestException as e:
            print(f"Network error during OpenAI API call: {str(e)}")
            return None
        except ValueError as e:
            print(f"OpenAI API Error: {str(e)}")
            return None
        except Exception as e:
            print(f"Unexpected error in OpenAI chat: {str(e)}")
            return None

    def chat_with_retry(
        self,
        messages: List[Dict[str, str]],
        provider: ModelProvider = ModelProvider.OPENAI,
        max_retries: int = 3,
        **kwargs,
    ) -> Optional[str]:
        """
        Chat completion with automatic retry on failure.

        The provider selects the method: OPENAI -> ``openai_chat``,
        DEEPSEEK -> ``deepseek_chat``, anything else -> ``simple_chat``.
        An attempt counts as failed when the method returns ``None`` or raises;
        attempts are retried immediately, without delay.

        Args:
            messages: List of message dictionaries.
            provider: Model provider (OPENAI or DEEPSEEK).
            max_retries: Maximum number of attempts; ``0`` or less makes no
                attempt and returns ``None``.
            **kwargs: Additional parameters forwarded to the chat method.

        Returns:
            The first non-``None`` result, or ``None`` if all attempts fail.

        Raises:
            AttributeError: If the chat method for ``provider`` is not defined
                on this class.
        """
        # Resolve the method by name only for the requested provider. Building
        # a dict of bound methods up front raised AttributeError for *every*
        # provider whenever one of them (e.g. deepseek_chat) was not defined.
        method_name = {
            ModelProvider.OPENAI: "openai_chat",
            ModelProvider.DEEPSEEK: "deepseek_chat",
        }.get(provider, "simple_chat")
        chat_method = getattr(self, method_name)

        for attempt in range(max_retries):
            try:
                result = chat_method(messages, **kwargs)
                if result is not None:
                    return result
            except Exception as e:
                if attempt == max_retries - 1:
                    print(f"Failed after {max_retries} attempts: {str(e)}")
                    return None
                print(f"Attempt {attempt + 1} failed, retrying...")
        return None