# GeminiBalance / app/service/chat/openai_chat_service.py
# CatPtain's picture
# Upload 77 files
# 76b9762 verified
# app/service/chat/openai_chat_service.py
import asyncio
import datetime
import json
import re
import time
from copy import deepcopy
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from app.config.config import settings
from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
from app.database.services import (
add_error_log,
add_request_log,
)
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
from app.handler.message_converter import OpenAIMessageConverter
from app.handler.response_handler import OpenAIResponseHandler
from app.handler.stream_optimizer import openai_optimizer
from app.log.logger import get_openai_logger
from app.service.client.api_client import GeminiApiClient
from app.service.image.image_create_service import ImageCreateService
from app.service.key.key_manager import KeyManager
# Module-level logger obtained from get_openai_logger(); used by every
# function and method in this module.
logger = get_openai_logger()
def _has_media_parts(contents: List[Dict[str, Any]]) -> bool:
    """Return True if any message carries an ``inline_data`` part.

    A message counts only when its ``"parts"`` value is a list and at least
    one element of that list is a dict containing the ``"inline_data"`` key
    (used for image, audio or video payloads).
    """

    def _carries_inline_data(message: Any) -> bool:
        # Falsy entries (None, {}) and messages without "parts" never match.
        if not message or "parts" not in message:
            return False
        parts = message["parts"]
        if not isinstance(parts, list):
            return False
        return any(isinstance(p, dict) and "inline_data" in p for p in parts)

    return any(_carries_inline_data(message) for message in contents)
def _build_tools(
    request: ChatRequest, messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Build the ``tools`` list for the Gemini payload.

    Args:
        request: Incoming chat request; ``request.model`` and
            ``request.tools`` are read.
        messages: Already-converted message contents; inspected for media
            (``inline_data``) parts.

    Returns:
        A single-element list holding the tool dict, or an empty list when
        no tool applies.

    Rules applied, in order:
      * ``codeExecution`` is enabled when the setting is on, the model name
        is not a search/thinking/image variant and no media part is present.
      * ``googleSearch`` is enabled for ``-search`` models, or when the
        client declares a function literally named ``googleSearch``.
      * Client function tools are de-duplicated by name and placed under
        ``functionDeclarations``; when any exist, the built-in tools are
        removed again.
    """
    tool: Dict[str, Any] = {}
    model = request.model
    # Evaluated once; both the enable check and the debug branch need it.
    has_media = _has_media_parts(messages)

    code_execution_excluded = "-thinking" in model or model.endswith(
        ("-search", "-image", "-image-generation")
    )
    if (
        settings.TOOLS_CODE_EXECUTION_ENABLED
        and not code_execution_excluded
        and not has_media
    ):
        tool["codeExecution"] = {}
        logger.debug("Code execution tool enabled.")
    elif has_media:
        logger.debug("Code execution tool disabled due to media parts presence.")

    if model.endswith("-search"):
        tool["googleSearch"] = {}

    # Merge the tools supplied on the request into the tool dict.
    if request.tools:
        seen_names = set()
        functions = []
        for item in request.tools:
            if not item or not isinstance(item, dict):
                continue
            if item.get("type", "") != "function" or not item.get("function"):
                continue
            function = deepcopy(item["function"])
            # "parameters" may be present but null; treat that like "absent"
            # instead of failing on None.get(...).
            parameters = function.get("parameters") or {}
            if (
                isinstance(parameters, dict)
                and parameters.get("type") == "object"
                and not parameters.get("properties")
            ):
                # An object schema without properties is dropped entirely.
                function.pop("parameters", None)

            name = function.get("name")
            if name in seen_names:
                continue  # de-duplicate by function name
            if name == "googleSearch":
                # Clients (e.g. Cherry) signal built-in search this way.
                tool["googleSearch"] = {}
            else:
                seen_names.add(name)
                functions.append(function)

        # Only attach the key when there is something to declare, so an
        # empty functionDeclarations list is never sent upstream.
        if functions:
            tool["functionDeclarations"] = functions

    # Works around "Tool use with function calling is unsupported":
    # function declarations cannot be combined with the built-in tools.
    if tool.get("functionDeclarations"):
        tool.pop("googleSearch", None)
        tool.pop("codeExecution", None)

    return [tool] if tool else []
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
    """Return the safety settings to send for ``model``.

    Only the exact model name ``gemini-2.0-flash-exp`` receives
    ``GEMINI_2_FLASH_EXP_SAFETY_SETTINGS``; every other model uses
    ``settings.SAFETY_SETTINGS``.
    """
    # NOTE: a broader check covering the "2.0" family (minus the thinking-exp
    # and pro-exp variants) was previously sketched here and left disabled.
    uses_flash_exp_settings = model == "gemini-2.0-flash-exp"
    return (
        GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
        if uses_flash_exp_settings
        else settings.SAFETY_SETTINGS
    )
def _build_payload(
    request: ChatRequest,
    messages: List[Dict[str, Any]],
    instruction: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Assemble the Gemini request payload for a chat request.

    Args:
        request: Source of the model name and sampling parameters.
        messages: Converted message contents, placed under ``contents``.
        instruction: Optional system instruction; attached only when it is a
            dict with role ``"system"`` and non-empty ``parts`` and the
            model is not an image variant.

    Returns:
        Dict with ``contents``, ``generationConfig``, ``tools`` and
        ``safetySettings``, plus ``systemInstruction`` when applicable.
    """
    model = request.model
    is_image_model = model.endswith(("-image", "-image-generation"))

    generation_config: Dict[str, Any] = {
        "temperature": request.temperature,
        "stopSequences": request.stop,
        "topP": request.top_p,
        "topK": request.top_k,
    }
    payload: Dict[str, Any] = {
        "contents": messages,
        "generationConfig": generation_config,
        "tools": _build_tools(request, messages),
        "safetySettings": _get_safety_settings(model),
    }

    if request.max_tokens is not None:
        generation_config["maxOutputTokens"] = request.max_tokens
    if is_image_model:
        generation_config["responseModalities"] = ["Text", "Image"]

    # A configured per-model budget takes precedence over the
    # "-non-thinking" default of 0 because it is assigned last.
    if model.endswith("-non-thinking"):
        generation_config["thinkingConfig"] = {"thinkingBudget": 0}
    if model in settings.THINKING_BUDGET_MAP:
        generation_config["thinkingConfig"] = {
            "thinkingBudget": settings.THINKING_BUDGET_MAP.get(model, 1000)
        }

    has_system_instruction = (
        instruction
        and isinstance(instruction, dict)
        and instruction.get("role") == "system"
        and instruction.get("parts")
    )
    if has_system_instruction and not is_image_model:
        payload["systemInstruction"] = instruction

    return payload
class OpenAIChatService:
    """Chat service exposing OpenAI-style completions backed by Gemini.

    Converts OpenAI-format requests into Gemini payloads, calls the Gemini
    API client (streaming or non-streaming), converts responses back through
    ``OpenAIResponseHandler`` and records request/error logs.
    """

    def __init__(self, base_url: str, key_manager: Optional[KeyManager] = None):
        """Create the service.

        Args:
            base_url: Base URL handed to ``GeminiApiClient``.
            key_manager: Optional key manager used to obtain a replacement
                API key when a streaming attempt fails. Without it,
                streaming is not retried.
        """
        self.message_converter = OpenAIMessageConverter()
        self.response_handler = OpenAIResponseHandler(config=None)
        self.api_client = GeminiApiClient(base_url, settings.TIME_OUT)
        self.key_manager = key_manager
        self.image_create_service = ImageCreateService()

    # ------------------------------------------------------------------
    # Small pure helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _parse_status_code(error_message: str, default: int = 500) -> int:
        """Extract ``status code <n>`` from an error message, else ``default``."""
        match = re.search(r"status code (\d+)", error_message)
        return int(match.group(1)) if match else default

    @staticmethod
    def _extract_sse_payload(line: str) -> str:
        """Return the payload of a ``data:`` SSE line, without the prefix.

        The space after the colon is optional, so the prefix is removed by
        its real length and surrounding whitespace is stripped afterwards.
        """
        return line[len("data:"):].strip()

    @staticmethod
    def _format_sse(data: Any) -> str:
        """Serialize ``data`` as one ``data: <json>`` SSE event."""
        return f"data: {json.dumps(data)}\n\n"

    def _extract_text_from_openai_chunk(self, chunk: Dict[str, Any]) -> str:
        """Return the delta text of the first choice, or ``""`` if absent."""
        choices = chunk.get("choices")
        if not choices:
            return ""
        delta = choices[0].get("delta") or {}
        # A null content is normalized to "" so the return type stays str.
        return delta.get("content") or ""

    def _create_char_openai_chunk(
        self, original_chunk: Dict[str, Any], text: str
    ) -> Dict[str, Any]:
        """Return a deep copy of ``original_chunk`` whose delta text is ``text``."""
        # JSON round-trip gives an independent copy of the chunk.
        chunk_copy = json.loads(json.dumps(original_chunk))
        if chunk_copy.get("choices") and "delta" in chunk_copy["choices"][0]:
            chunk_copy["choices"][0]["delta"]["content"] = text
        return chunk_copy

    # ------------------------------------------------------------------
    # Chat completion
    # ------------------------------------------------------------------
    async def create_chat_completion(
        self,
        request: ChatRequest,
        api_key: str,
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Create a chat completion.

        Returns an async generator of SSE strings when ``request.stream`` is
        set, otherwise the converted response dict.
        """
        messages, instruction = self.message_converter.convert(request.messages)
        payload = _build_payload(request, messages, instruction)
        if request.stream:
            return self._handle_stream_completion(request.model, payload, api_key)
        return await self._handle_normal_completion(request.model, payload, api_key)

    async def _handle_normal_completion(
        self, model: str, payload: Dict[str, Any], api_key: str
    ) -> Dict[str, Any]:
        """Handle a non-streaming completion.

        Logs an error entry and re-raises on failure; always writes a request
        log entry with the measured latency.
        """
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None
        try:
            response = await self.api_client.generate_content(payload, model, api_key)
            usage_metadata = response.get("usageMetadata", {})
            is_success = True
            status_code = 200
            return self.response_handler.handle_response(
                response,
                model,
                stream=False,
                finish_reason="stop",
                usage_metadata=usage_metadata,
            )
        except Exception as e:
            is_success = False
            error_log_msg = str(e)
            logger.error(f"Normal API call failed with error: {error_log_msg}")
            status_code = self._parse_status_code(error_log_msg)
            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type="openai-chat-non-stream",
                error_log=error_log_msg,
                error_code=status_code,
                request_msg=payload,
            )
            raise
        finally:
            latency_ms = int((time.perf_counter() - start_time) * 1000)
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime,
            )

    async def _fake_stream_logic_impl(
        self, model: str, payload: Dict[str, Any], api_key: str
    ) -> AsyncGenerator[str, None]:
        """Core logic of the fake stream.

        Calls the non-streaming endpoint in a background task and, while it
        is pending, emits an empty chunk every
        ``settings.FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS`` to keep the
        connection alive. The full response is then sent as a single chunk.
        Exceptions from the API call propagate to the caller.
        """
        logger.info(
            f"Fake streaming enabled for model: {model}. Calling non-streaming endpoint."
        )
        heartbeat_interval = settings.FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS
        api_response_task = asyncio.create_task(
            self.api_client.generate_content(payload, model, api_key)
        )
        try:
            while True:
                # asyncio.wait() does not cancel the task on timeout, so the
                # same task can be polled repeatedly between heartbeats.
                done, _ = await asyncio.wait(
                    {api_response_task}, timeout=heartbeat_interval
                )
                if done:
                    break
                empty_chunk = self.response_handler.handle_response(
                    {}, model, stream=True, finish_reason='stop', usage_metadata=None
                )
                yield self._format_sse(empty_chunk)
                logger.debug("Sent empty data chunk for fake stream heartbeat.")
            # Re-raises whatever the API call raised.
            response = api_response_task.result()
        finally:
            # If the consumer went away (or an error occurred) while the call
            # was still in flight, do not leave the task running.
            if not api_response_task.done():
                api_response_task.cancel()

        if response and response.get("candidates"):
            response = self.response_handler.handle_response(
                response,
                model,
                stream=True,
                finish_reason='stop',
                usage_metadata=response.get("usageMetadata", {}),
            )
            yield self._format_sse(response)
            logger.info(f"Sent full response content for fake stream: {model}")
        else:
            error_message = "Failed to get response from model"
            if response and isinstance(response, dict) and response.get("error"):
                error_details = response.get("error")
                if isinstance(error_details, dict):
                    error_message = error_details.get("message", error_message)
            logger.error(
                f"No candidates or error in response for fake stream model {model}: "
                f"{error_message}. Response: {response}"
            )
            error_chunk = self.response_handler.handle_response(
                {}, model, stream=True, finish_reason='stop', usage_metadata=None
            )
            yield self._format_sse(error_chunk)

    async def _real_stream_logic_impl(
        self, model: str, payload: Dict[str, Any], api_key: str
    ) -> AsyncGenerator[str, None]:
        """Core logic of the real stream.

        Reads ``data:`` lines from the upstream stream, converts each chunk
        and yields it as an SSE event (through the stream optimizer when the
        chunk has text and the optimizer is enabled). A final chunk with
        finish reason ``tool_calls`` or ``stop`` closes the stream.
        """
        tool_call_flag = False
        usage_metadata = None
        async for line in self.api_client.stream_generate_content(
            payload, model, api_key
        ):
            if not line.startswith("data:"):
                continue
            chunk_str = self._extract_sse_payload(line)
            if not chunk_str:
                logger.debug(f"Received empty data line for model {model}, skipping.")
                continue
            try:
                chunk = json.loads(chunk_str)
            except json.JSONDecodeError:
                logger.error(
                    f"Failed to decode JSON from stream for model {model}: {chunk_str}"
                )
                continue
            usage_metadata = chunk.get("usageMetadata", {})

            openai_chunk = self.response_handler.handle_response(
                chunk,
                model,
                stream=True,
                finish_reason=None,
                usage_metadata=usage_metadata,
            )
            if not openai_chunk:
                continue

            text = self._extract_text_from_openai_chunk(openai_chunk)
            if text and settings.STREAM_OPTIMIZER_ENABLED:
                async for optimized_chunk_data in openai_optimizer.optimize_stream_output(
                    text,
                    lambda t: self._create_char_openai_chunk(openai_chunk, t),
                    lambda c: f"data: {json.dumps(c)}\n\n",
                ):
                    yield optimized_chunk_data
            else:
                if openai_chunk.get("choices") and openai_chunk["choices"][0].get(
                    "delta", {}
                ).get("tool_calls"):
                    tool_call_flag = True
                yield self._format_sse(openai_chunk)

        final_reason = 'tool_calls' if tool_call_flag else 'stop'
        yield self._format_sse(
            self.response_handler.handle_response(
                {},
                model,
                stream=True,
                finish_reason=final_reason,
                usage_metadata=usage_metadata,
            )
        )

    async def _handle_stream_completion(
        self, model: str, payload: Dict[str, Any], api_key: str
    ) -> AsyncGenerator[str, None]:
        """Handle a streaming completion with retries and fake-stream support.

        Each attempt uses the fake or the real stream depending on
        ``settings.FAKE_STREAM_ENABLED``. On failure an error log is written
        and the key manager is asked for a replacement key; attempts stop
        when no key manager or no key is available, or after
        ``settings.MAX_RETRIES`` attempts. Every attempt writes a request log
        entry. If all attempts fail, an error event and ``[DONE]`` are sent.
        """
        retries = 0
        max_retries = settings.MAX_RETRIES
        is_success = False
        status_code = None
        final_api_key = api_key

        while retries < max_retries:
            start_time = time.perf_counter()
            request_datetime = datetime.datetime.now()
            current_attempt_key = final_api_key
            try:
                if settings.FAKE_STREAM_ENABLED:
                    logger.info(
                        f"Using fake stream logic for model: {model}, Attempt: {retries + 1}"
                    )
                    stream_generator = self._fake_stream_logic_impl(
                        model, payload, current_attempt_key
                    )
                else:
                    logger.info(
                        f"Using real stream logic for model: {model}, Attempt: {retries + 1}"
                    )
                    stream_generator = self._real_stream_logic_impl(
                        model, payload, current_attempt_key
                    )

                async for chunk_data in stream_generator:
                    yield chunk_data
                yield "data: [DONE]\n\n"
                logger.info(
                    f"Streaming completed successfully for model: {model}, FakeStream: {settings.FAKE_STREAM_ENABLED}, Attempt: {retries + 1}"
                )
                is_success = True
                status_code = 200
                break
            except Exception as e:
                retries += 1
                is_success = False
                error_log_msg = str(e)
                # NOTE(review): this message contains the raw API key;
                # consider redacting it before logging.
                logger.warning(
                    f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries} with key {current_attempt_key}"
                )
                # Prefer the upstream status code; otherwise 408 for
                # timeouts and 500 for everything else.
                fallback_code = 408 if isinstance(e, asyncio.TimeoutError) else 500
                status_code = self._parse_status_code(
                    error_log_msg, default=fallback_code
                )
                await add_error_log(
                    gemini_key=current_attempt_key,
                    model_name=model,
                    error_type="openai-chat-stream",
                    error_log=error_log_msg,
                    error_code=status_code,
                    request_msg=payload,
                )

                if self.key_manager:
                    new_api_key = await self.key_manager.handle_api_failure(
                        current_attempt_key, retries
                    )
                    if new_api_key and new_api_key != current_attempt_key:
                        final_api_key = new_api_key
                        logger.info(
                            f"Switched to new API key for next attempt: {final_api_key}"
                        )
                    elif not new_api_key:
                        logger.error(
                            f"No valid API key available after {retries} retries, ceasing attempts for this request."
                        )
                        break
                else:
                    logger.error(
                        "KeyManager not available, cannot switch API key. Ceasing attempts for this request."
                    )
                    break

                if retries >= max_retries:
                    logger.error(
                        f"Max retries ({max_retries}) reached for streaming model {model}."
                    )
            finally:
                latency_ms = int((time.perf_counter() - start_time) * 1000)
                await add_request_log(
                    model_name=model,
                    api_key=current_attempt_key,
                    is_success=is_success,
                    status_code=status_code,
                    latency_ms=latency_ms,
                    request_time=request_datetime,
                )

        if not is_success:
            logger.error(
                f"Streaming failed permanently for model {model} after {retries} attempts."
            )
            yield self._format_sse(
                {"error": f"Streaming failed after {retries} retries."}
            )
            yield "data: [DONE]\n\n"

    # ------------------------------------------------------------------
    # Image chat completion
    # ------------------------------------------------------------------
    async def create_image_chat_completion(
        self, request: ChatRequest, api_key: str
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Generate an image from the last message and return it as a chat reply.

        The ``content`` of the last message is used as the image prompt.
        Returns an async generator of SSE strings when ``request.stream`` is
        set, otherwise the converted response dict.
        """
        image_generate_request = ImageGenerationRequest()
        image_generate_request.prompt = request.messages[-1]["content"]
        # NOTE(review): called without await; presumably a synchronous
        # method - confirm against ImageCreateService.
        image_res = self.image_create_service.generate_images_chat(
            image_generate_request
        )
        if request.stream:
            return self._handle_stream_image_completion(
                request.model, image_res, api_key
            )
        return await self._handle_normal_image_completion(
            request.model, image_res, api_key
        )

    async def _handle_stream_image_completion(
        self, model: str, image_data: str, api_key: str
    ) -> AsyncGenerator[str, None]:
        """Stream an image chat response as SSE events.

        Text content goes through the stream optimizer; a chunk without text
        is emitted whole. On failure an error event is emitted instead of
        raising. A request log entry is always written.
        """
        logger.info(f"Starting stream image completion for model: {model}")
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None
        try:
            if image_data:
                openai_chunk = self.response_handler.handle_image_chat_response(
                    image_data, model, stream=True, finish_reason=None
                )
                if openai_chunk:
                    text = self._extract_text_from_openai_chunk(openai_chunk)
                    if text:
                        # Text output goes through the stream optimizer.
                        async for optimized_chunk in openai_optimizer.optimize_stream_output(
                            text,
                            lambda t: self._create_char_openai_chunk(openai_chunk, t),
                            lambda c: f"data: {json.dumps(c)}\n\n",
                        ):
                            yield optimized_chunk
                    else:
                        # No text content (e.g. an image URL): emit as is.
                        yield self._format_sse(openai_chunk)
            yield self._format_sse(
                self.response_handler.handle_response(
                    {}, model, stream=True, finish_reason='stop'
                )
            )
            logger.info(
                f"Stream image completion finished successfully for model: {model}"
            )
            is_success = True
            status_code = 200
            yield "data: [DONE]\n\n"
        except Exception as e:
            is_success = False
            error_log_msg = f"Stream image completion failed for model {model}: {e}"
            logger.error(error_log_msg)
            status_code = 500
            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type="openai-image-stream",
                error_log=error_log_msg,
                error_code=status_code,
                # image_data may be empty/None; never fail inside the handler.
                request_msg={"image_data_truncated": (image_data or "")[:1000]},
            )
            yield self._format_sse({'error': error_log_msg})
            yield "data: [DONE]\n\n"
        finally:
            latency_ms = int((time.perf_counter() - start_time) * 1000)
            logger.info(
                f"Stream image completion for model {model} took {latency_ms} ms. Success: {is_success}"
            )
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime,
            )

    async def _handle_normal_image_completion(
        self, model: str, image_data: str, api_key: str
    ) -> Dict[str, Any]:
        """Return an image chat response as a single (non-streaming) dict.

        Logs an error entry and re-raises on failure; always writes a request
        log entry with the measured latency.
        """
        logger.info(f"Starting normal image completion for model: {model}")
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None
        try:
            result = self.response_handler.handle_image_chat_response(
                image_data, model, stream=False, finish_reason="stop"
            )
            logger.info(
                f"Normal image completion finished successfully for model: {model}"
            )
            is_success = True
            status_code = 200
            return result
        except Exception as e:
            is_success = False
            error_log_msg = f"Normal image completion failed for model {model}: {e}"
            logger.error(error_log_msg)
            status_code = 500
            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type="openai-image-non-stream",
                error_log=error_log_msg,
                error_code=status_code,
                # image_data may be empty/None; never fail inside the handler.
                request_msg={"image_data_truncated": (image_data or "")[:1000]},
            )
            raise
        finally:
            latency_ms = int((time.perf_counter() - start_time) * 1000)
            logger.info(
                f"Normal image completion for model {model} took {latency_ms} ms. Success: {is_success}"
            )
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime,
            )