Spaces:
Sleeping
Sleeping
File size: 7,707 Bytes
2cf6258 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 |
from typing import List, Dict, Optional, Union
from openai import OpenAI
from openai.types.chat import ChatCompletion
from openai._types import NotGiven, NOT_GIVEN
import openai
import requests
from dataclasses import dataclass
from enum import Enum
class ModelProvider(Enum):
    """Chat-completion backends that ``CustomOpenAI`` can dispatch to.

    Each member's value is the provider's lowercase name.
    """

    OPENAI = "openai"  # OpenAI's hosted API
    DEEPSEEK = "deepseek"  # DeepSeek's API, called over plain HTTP
@dataclass
class ApiConfig:
    """Location of a provider's HTTP API, used to build endpoint URLs."""

    # Scheme + host, e.g. "https://api.openai.com"; a trailing "/" is tolerated.
    base_url: str
    # Version path segment placed after base_url; "" means no version segment.
    api_version: str = "v1"

    @property
    def chat_endpoint(self) -> str:
        """Return the full URL of the chat completions endpoint.

        ``base_url``, ``api_version`` and ``chat/completions`` are joined
        with single slashes. A trailing slash on ``base_url``, surrounding
        slashes on ``api_version``, or an empty ``api_version`` therefore
        never produce ``//`` in the path.
        """
        parts = [self.base_url.rstrip("/")]
        version = self.api_version.strip("/")
        if version:
            parts.append(version)
        parts.append("chat/completions")
        return "/".join(parts)
class CustomOpenAI(OpenAI):
    """``openai.OpenAI`` client with convenience chat helpers.

    On top of the inherited OpenAI client API this adds:

    * ``simple_chat``: OpenAI chat completion returning only the text of the
      first choice.
    * ``deepseek_chat``: chat completion against DeepSeek's HTTP endpoint,
      sent with ``requests`` and a separate DeepSeek API key.
    * ``chat_with_retry``: dispatches to one of the two and re-attempts when
      it fails.

    Request failures are reported by printing a message and returning
    ``None`` rather than raising.
    """

    # Base URL / API version per provider. Only the DeepSeek entry is read in
    # this class (by deepseek_chat).
    API_CONFIGS = {
        ModelProvider.OPENAI: ApiConfig("https://api.openai.com"),
        # DeepSeek's documented API host; it accepts the "/v1" prefix for
        # OpenAI compatibility (-> https://api.deepseek.com/v1/chat/completions).
        ModelProvider.DEEPSEEK: ApiConfig("https://api.deepseek.com"),
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        organization: Optional[str] = None,
        deepseek_api_key: Optional[str] = None,
        deepseek_verify_ssl: bool = True,
        **kwargs,
    ):
        """
        Initialize CustomOpenAI client with enhanced chat functionality.

        Args:
            api_key: OpenAI API key, forwarded to ``openai.OpenAI``.
            organization: OpenAI organization ID (optional), forwarded to
                ``openai.OpenAI``.
            deepseek_api_key: DeepSeek API key (optional). Only
                ``deepseek_chat`` needs it.
            deepseek_verify_ssl: Whether ``deepseek_chat`` verifies the TLS
                certificate of the DeepSeek endpoint (default ``True``).
                The API key is sent with every request, so set this to
                ``False`` only in a development environment.
            **kwargs: Additional ``openai.OpenAI`` client configuration.
        """
        super().__init__(
            api_key=api_key,
            organization=organization,
            **kwargs,
        )
        self.deepseek_api_key = deepseek_api_key
        self.deepseek_verify_ssl = deepseek_verify_ssl

    def simple_chat(
        self,
        messages: List[Dict[str, str]],
        model: str = "gpt-3.5-turbo",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> Optional[str]:
        """
        Simplified chat completion method that returns just the message content.

        Args:
            messages: List of message dictionaries with 'role' and 'content'.
            model: Model identifier to use.
            temperature: Sampling temperature (0-2).
            max_tokens: Maximum number of tokens to generate; omitted from the
                request when ``None``.
            **kwargs: Additional parameters for ``chat.completions.create``.
                Applied last, so they override the parameters above.

        Returns:
            Content of the first choice's message. Returns ``None`` if the
            response has no choices, if that message has no content, or if
            any error occurs. Errors are printed, not raised.
        """
        params = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
        }
        if max_tokens is not None:
            params["max_tokens"] = max_tokens
        params.update(kwargs)

        try:
            # Uses the chat completions resource inherited from openai.OpenAI.
            response: ChatCompletion = self.chat.completions.create(**params)
            # Extraction stays inside the try. With e.g. stream=True in
            # kwargs, the result has no .choices, and that failure is
            # reported like any other error.
            if not response.choices:
                return None
            return response.choices[0].message.content
        except openai.APIError as e:
            print(f"OpenAI API Error: {str(e)}")
            return None
        except Exception as e:
            print(f"Unexpected error: {str(e)}")
            return None

    def deepseek_chat(
        self,
        messages: List[Dict[str, str]],
        model: str = "deepseek-chat",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> Optional[str]:
        """
        Chat completion method for DeepSeek models.

        Sends a non-streaming POST to ``API_CONFIGS[ModelProvider.DEEPSEEK]``'s
        chat endpoint with ``requests``.

        Args:
            messages: List of message dictionaries with 'role' and 'content'.
            model: DeepSeek model identifier.
            temperature: Sampling temperature (0-2).
            max_tokens: Maximum number of tokens to generate; omitted from the
                request body when ``None``.
            **kwargs: Additional request-body fields. Applied last, so they
                override the fields above (including ``stream``).

        Returns:
            Content of the first choice's message. Returns ``None`` on a
            network error, a non-200 status, a response without choices, or
            any other error. Errors are printed, not raised.

        Raises:
            ValueError: If no DeepSeek API key was configured.
        """
        if not self.deepseek_api_key:
            raise ValueError("DeepSeek API key is required for deepseek_chat")

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.deepseek_api_key}",
            "User-Agent": "DeepseekClient/1.0",
        }
        data = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            # The response is parsed as a single JSON document below.
            "stream": False,
        }
        if max_tokens is not None:
            data["max_tokens"] = max_tokens
        data.update(kwargs)

        config = self.API_CONFIGS[ModelProvider.DEEPSEEK]
        verify = self.deepseek_verify_ssl
        if not verify:
            # Verification was explicitly disabled (development only). Silence
            # urllib3's InsecureRequestWarning *before* the request so the
            # first call does not emit it either. The effect is process-wide.
            import urllib3

            urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

        try:
            response = requests.post(
                config.chat_endpoint,
                headers=headers,
                json=data,
                verify=verify,
                timeout=(10, 60),  # (connect timeout, read timeout) in seconds
            )

            if response.status_code != 200:
                error_msg = f"DeepSeek API request failed with status {response.status_code}"
                try:
                    error_data = response.json()
                    print(f"[DEBUG] Error response: {error_data}")
                    if "error" in error_data:
                        error_msg += f": {error_data['error']}"
                except Exception as e:
                    print(f"[DEBUG] Failed to parse error response: {str(e)}")
                # Reported by the ValueError handler below.
                raise ValueError(error_msg)

            response_data = response.json()
            if not response_data.get("choices"):
                raise ValueError("No choices in DeepSeek API response")
            return response_data["choices"][0]["message"]["content"]
        except requests.exceptions.RequestException as e:
            print(f"Network error during DeepSeek API call: {str(e)}")
            return None
        except ValueError as e:
            print(f"DeepSeek API Error: {str(e)}")
            return None
        except Exception as e:
            print(f"Unexpected error in DeepSeek chat: {str(e)}")
            return None

    def chat_with_retry(
        self,
        messages: List[Dict[str, str]],
        provider: ModelProvider = ModelProvider.OPENAI,
        max_retries: int = 3,
        retry_delay: float = 0.0,
        **kwargs,
    ) -> Optional[str]:
        """
        Chat completion with automatic retry on failure.

        An attempt fails when the chat method returns ``None`` or raises.

        Args:
            messages: List of message dictionaries.
            provider: Model provider. ``ModelProvider.OPENAI`` uses
                ``simple_chat``; any other value uses ``deepseek_chat``.
            max_retries: Total number of attempts (including the first).
                A value <= 0 makes no attempt.
            retry_delay: Seconds to wait before the second attempt. The wait
                doubles before each further attempt (exponential backoff).
                The default ``0`` retries immediately.
            **kwargs: Additional parameters for the chat method.

        Returns:
            Generated message content, or ``None`` if every attempt fails or
            the client is misconfigured (missing DeepSeek API key).
        """
        import time

        chat_method = self.simple_chat if provider == ModelProvider.OPENAI else self.deepseek_chat

        for attempt in range(1, max_retries + 1):
            try:
                result = chat_method(messages, **kwargs)
            except ValueError as e:
                # Both chat methods report request failures by returning None.
                # The ValueError that escapes is deepseek_chat's missing-key
                # check, a configuration error that retrying cannot fix.
                print(f"Not retrying: {str(e)}")
                return None
            except Exception as e:
                error = str(e)
            else:
                if result is not None:
                    return result
                error = "no content returned"

            if attempt == max_retries:
                print(f"Failed after {max_retries} attempts: {error}")
                return None
            print(f"Attempt {attempt} failed, retrying...")
            if retry_delay > 0:
                time.sleep(retry_delay * 2 ** (attempt - 1))
        return None
|