from typing import List, Dict, Optional

from dataclasses import dataclass
from enum import Enum

import openai
import requests
import urllib3
from openai import OpenAI
from openai.types.chat import ChatCompletion


class ModelProvider(Enum):
    OPENAI = "openai"
    DEEPSEEK = "deepseek"


@dataclass
class ApiConfig:
    base_url: str
    api_version: str = "v1"

    @property
    def chat_endpoint(self) -> str:
        return f"{self.base_url}/{self.api_version}/chat/completions"


class CustomOpenAI(OpenAI):
    # API configuration for the different providers
    API_CONFIGS = {
        ModelProvider.OPENAI: ApiConfig("https://api.openai.com"),
        ModelProvider.DEEPSEEK: ApiConfig("https://api.deepseek.com"),
    }
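    # For illustration: API_CONFIGS[ModelProvider.OPENAI].chat_endpoint resolves
    # to "https://api.openai.com/v1/chat/completions".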

    def __init__(
        self,
        api_key: Optional[str] = None,
        organization: Optional[str] = None,
        deepseek_api_key: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize the CustomOpenAI client with enhanced chat functionality.

        Args:
            api_key: OpenAI API key
            organization: Organization ID (optional)
            deepseek_api_key: DeepSeek API key (optional)
            **kwargs: Additional client configuration parameters
        """
        super().__init__(
            api_key=api_key,
            organization=organization,
            **kwargs,
        )
        self.deepseek_api_key = deepseek_api_key

    def simple_chat(
        self,
        messages: List[Dict[str, str]],
        model: str = "gpt-3.5-turbo",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> Optional[str]:
        """
        Simplified chat completion method that returns just the message content.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            model: Model identifier to use
            temperature: Sampling temperature (0-2)
            max_tokens: Maximum number of tokens to generate
            **kwargs: Additional parameters to pass to the API

        Returns:
            Generated message content or None if an error occurs
        """
        try:
            # Prepare the request parameters
            params = {
                "model": model,
                "messages": messages,
                "temperature": temperature,
            }
            # Add max_tokens only if specified
            if max_tokens is not None:
                params["max_tokens"] = max_tokens
            # Add any additional kwargs
            params.update(kwargs)

            # Make the API call using the inherited chat completions method
            response: ChatCompletion = self.chat.completions.create(**params)

            # Extract the message content from the first choice
            if response.choices and len(response.choices) > 0:
                return response.choices[0].message.content
            return None
        except openai.APIError as e:
            print(f"OpenAI API Error: {str(e)}")
            return None
        except Exception as e:
            print(f"Unexpected error: {str(e)}")
            return None

    def openai_chat(
        self,
        messages: List[Dict[str, str]],
        model: str = "gpt-4o-mini",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> Optional[str]:
        """
        Chat completion method for OpenAI models, called via the raw HTTP API.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            model: OpenAI model identifier
            temperature: Sampling temperature (0-2)
            max_tokens: Maximum number of tokens to generate
            **kwargs: Additional parameters to pass to the API

        Returns:
            Generated message content or None if an error occurs
        """
        if not self.api_key:
            raise ValueError("OpenAI API key is required for openai_chat")
        try:
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}",
            }
            # Only send the organization header when an organization is configured
            if self.organization:
                headers["OpenAI-Organization"] = self.organization

            data = {
                "model": model,
                "messages": messages,
                "temperature": temperature,
                "stream": False,
            }
            if max_tokens is not None:
                data["max_tokens"] = max_tokens
            # Add any additional kwargs to the request data
            data.update(kwargs)

            config = self.API_CONFIGS[ModelProvider.OPENAI]

            # Suppress the warning emitted when SSL verification is disabled
            urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

            response = requests.post(
                config.chat_endpoint,
                headers=headers,
                json=data,
                verify=False,  # Disable SSL verification (development environments only)
                timeout=(10, 60),  # 10 s connect timeout, 60 s read timeout
            )
            if response.status_code != 200:
                error_msg = f"OpenAI API request failed with status {response.status_code}"
                try:
                    error_data = response.json()
                    print(f"[DEBUG] Error response: {error_data}")
                    if "error" in error_data:
                        error_msg += f": {error_data['error']}"
                except Exception as e:
                    print(f"[DEBUG] Failed to parse error response: {str(e)}")
                raise ValueError(error_msg)

            response_data = response.json()
            if not response_data.get("choices"):
                raise ValueError("No choices in OpenAI API response")
            return response_data["choices"][0]["message"]["content"]
        except requests.exceptions.RequestException as e:
            print(f"Network error during OpenAI API call: {str(e)}")
            return None
        except ValueError as e:
            print(f"OpenAI API Error: {str(e)}")
            return None
        except Exception as e:
            print(f"Unexpected error in OpenAI chat: {str(e)}")
            return None
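
    # chat_with_retry below refers to self.deepseek_chat. If the surrounding file
    # does not already define it, a minimal sketch could look like the following.
    # It assumes DeepSeek's OpenAI-compatible chat completions endpoint and the
    # "deepseek-chat" model name; adjust to the actual deployment as needed.
    def deepseek_chat(
        self,
        messages: List[Dict[str, str]],
        model: str = "deepseek-chat",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> Optional[str]:
        """Chat completion method for DeepSeek models (illustrative sketch)."""
        if not self.deepseek_api_key:
            raise ValueError("DeepSeek API key is required for deepseek_chat")
        try:
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.deepseek_api_key}",
            }
            data = {
                "model": model,
                "messages": messages,
                "temperature": temperature,
                "stream": False,
            }
            if max_tokens is not None:
                data["max_tokens"] = max_tokens
            data.update(kwargs)

            config = self.API_CONFIGS[ModelProvider.DEEPSEEK]
            response = requests.post(
                config.chat_endpoint,
                headers=headers,
                json=data,
                timeout=(10, 60),
            )
            response.raise_for_status()

            response_data = response.json()
            if not response_data.get("choices"):
                raise ValueError("No choices in DeepSeek API response")
            return response_data["choices"][0]["message"]["content"]
        except Exception as e:
            print(f"DeepSeek API error: {str(e)}")
            return None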

    def chat_with_retry(
        self,
        messages: List[Dict[str, str]],
        provider: ModelProvider = ModelProvider.OPENAI,
        max_retries: int = 3,
        **kwargs,
    ) -> Optional[str]:
        """
        Chat completion with automatic retry on failure.

        Args:
            messages: List of message dictionaries
            provider: Model provider (OPENAI or DEEPSEEK)
            max_retries: Maximum number of retry attempts
            **kwargs: Additional parameters for the chat methods

        Returns:
            Generated message content or None if all retries fail
        """
        chat_method = {
            ModelProvider.OPENAI: self.openai_chat,
            ModelProvider.DEEPSEEK: self.deepseek_chat,
        }.get(provider, self.simple_chat)

        for attempt in range(max_retries):
            try:
                result = chat_method(messages, **kwargs)
                if result is not None:
                    return result
            except Exception as e:
                if attempt == max_retries - 1:
                    print(f"Failed after {max_retries} attempts: {str(e)}")
                    return None
                print(f"Attempt {attempt + 1} failed, retrying...")
        return None
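

# Illustrative usage (not part of the original module). The API keys below are
# placeholders; substitute real credentials and model names before running.
if __name__ == "__main__":
    client = CustomOpenAI(
        api_key="sk-your-openai-key",          # placeholder OpenAI key
        deepseek_api_key="sk-your-deepseek-key",  # placeholder DeepSeek key
    )
    reply = client.chat_with_retry(
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        provider=ModelProvider.OPENAI,
        max_retries=2,
    )
    print(reply)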