from typing import List, Dict, Optional, Union
from openai import OpenAI
from openai.types.chat import ChatCompletion
from openai._types import NotGiven, NOT_GIVEN
import openai
import requests
from dataclasses import dataclass
from enum import Enum


class ModelProvider(Enum):
    OPENAI = "openai"
    DEEPSEEK = "deepseek"


@dataclass
class ApiConfig:
    base_url: str
    api_version: str = "v1"

    @property
    def chat_endpoint(self) -> str:
        return f"{self.base_url}/{self.api_version}/chat/completions"


class CustomOpenAI(OpenAI):
    # API configuration for the supported providers
    API_CONFIGS = {
        ModelProvider.OPENAI: ApiConfig("https://api.openai.com"),
        ModelProvider.DEEPSEEK: ApiConfig("https://api.deepseek.com"),
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        organization: Optional[str] = None,
        deepseek_api_key: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize the CustomOpenAI client with enhanced chat functionality.

        Args:
            api_key: OpenAI API key
            organization: Organization ID (optional)
            deepseek_api_key: DeepSeek API key (optional)
            **kwargs: Additional client configuration parameters
        """
        super().__init__(
            api_key=api_key,
            organization=organization,
            **kwargs
        )
        self.deepseek_api_key = deepseek_api_key

    def simple_chat(
        self,
        messages: List[Dict[str, str]],
        model: str = "gpt-3.5-turbo",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> Optional[str]:
        """
        Simplified chat completion method that returns just the message content.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            model: Model identifier to use
            temperature: Sampling temperature (0-2)
            max_tokens: Maximum number of tokens to generate
            **kwargs: Additional parameters to pass to the API

        Returns:
            Generated message content, or None if an error occurs
        """
        try:
            # Prepare request parameters
            params = {
                "model": model,
                "messages": messages,
                "temperature": temperature,
            }
            # Add max_tokens only if specified
            if max_tokens is not None:
                params["max_tokens"] = max_tokens
            # Add any additional kwargs
            params.update(kwargs)
            # Make the API call using the inherited chat completions client
            response: ChatCompletion = self.chat.completions.create(**params)
            # Extract the message content from the first choice
            if response.choices and len(response.choices) > 0:
                return response.choices[0].message.content
            return None
        except openai.APIError as e:
            print(f"OpenAI API Error: {str(e)}")
            return None
        except Exception as e:
            print(f"Unexpected error: {str(e)}")
            return None

    def openai_chat(
        self,
        messages: List[Dict[str, str]],
        model: str = "gpt-4o-mini",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> Optional[str]:
        """
        Chat completion method for OpenAI models, using a raw HTTP request.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            model: OpenAI model identifier
            temperature: Sampling temperature (0-2)
            max_tokens: Maximum number of tokens to generate
            **kwargs: Additional parameters to pass to the API

        Returns:
            Generated message content, or None if an error occurs
        """
        if not self.api_key:
            raise ValueError("OpenAI API key is required for openai_chat")
        try:
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}",
                "OpenAI-Organization": self.organization if self.organization else ""
            }
            data = {
                "model": model,
                "messages": messages,
                "temperature": temperature,
                "stream": False
            }
            if max_tokens is not None:
                data["max_tokens"] = max_tokens
            # Add any additional kwargs to the request payload
            data.update(kwargs)

            # Suppress the warning emitted when SSL verification is disabled
            import urllib3
            urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

            config = self.API_CONFIGS[ModelProvider.OPENAI]
            response = requests.post(
                config.chat_endpoint,
                headers=headers,
                json=data,
                verify=False,     # Disable SSL verification (development environments only)
                timeout=(10, 60)  # 10-second connect timeout, 60-second read timeout
            )
            if response.status_code != 200:
                error_msg = f"OpenAI API request failed with status {response.status_code}"
                try:
                    error_data = response.json()
                    print(f"[DEBUG] Error response: {error_data}")
                    if "error" in error_data:
                        error_msg += f": {error_data['error']}"
                except Exception as e:
                    print(f"[DEBUG] Failed to parse error response: {str(e)}")
                raise ValueError(error_msg)
            response_data = response.json()
            if not response_data.get("choices"):
                raise ValueError("No choices in OpenAI API response")
            return response_data["choices"][0]["message"]["content"]
        except requests.exceptions.RequestException as e:
            print(f"Network error during OpenAI API call: {str(e)}")
            return None
        except ValueError as e:
            print(f"OpenAI API Error: {str(e)}")
            return None
        except Exception as e:
            print(f"Unexpected error in OpenAI chat: {str(e)}")
            return None
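
    # Illustrative call sketch (not part of the original file; assumes a client
    # constructed with a valid OpenAI API key):
    #     client.openai_chat(
    #         [{"role": "user", "content": "Hi"}],
    #         model="gpt-4o-mini",
    #         max_tokens=64,
    #     )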

    def chat_with_retry(
        self,
        messages: List[Dict[str, str]],
        provider: ModelProvider = ModelProvider.OPENAI,
        max_retries: int = 3,
        **kwargs
    ) -> Optional[str]:
        """
        Chat completion with automatic retry on failure.

        Args:
            messages: List of message dictionaries
            provider: Model provider (OPENAI or DEEPSEEK)
            max_retries: Maximum number of retry attempts
            **kwargs: Additional parameters for the chat methods

        Returns:
            Generated message content, or None if all retries fail
        """
        chat_method = {
            ModelProvider.OPENAI: self.openai_chat,
            ModelProvider.DEEPSEEK: self.deepseek_chat,
        }.get(provider, self.simple_chat)
        for attempt in range(max_retries):
            try:
                result = chat_method(messages, **kwargs)
                if result is not None:
                    return result
            except Exception as e:
                if attempt == max_retries - 1:
                    print(f"Failed after {max_retries} attempts: {str(e)}")
                    return None
            print(f"Attempt {attempt + 1} failed, retrying...")
        return None
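

# Minimal usage sketch (illustrative, not part of the original file): assumes the
# OPENAI_API_KEY environment variable is set and outbound network access is available.
if __name__ == "__main__":
    import os

    client = CustomOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

    # Single-shot helper built on the SDK's chat.completions endpoint.
    reply = client.simple_chat(
        [{"role": "user", "content": "Say hello in one short sentence."}],
        model="gpt-3.5-turbo",
        max_tokens=50,
    )
    print(reply)

    # Retry wrapper: returns None if every attempt fails.
    retried = client.chat_with_retry(
        [{"role": "user", "content": "Ping"}],
        provider=ModelProvider.OPENAI,
        max_retries=2,
    )
    print(retried)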