# app/services/chat_service.py

import asyncio
import datetime
import json
import re
import time
from copy import deepcopy
from typing import Any, AsyncGenerator, Dict, List, Optional, Union

from app.config.config import settings
from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
from app.database.services import (
    add_error_log,
    add_request_log,
)
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
from app.handler.message_converter import OpenAIMessageConverter
from app.handler.response_handler import OpenAIResponseHandler
from app.handler.stream_optimizer import openai_optimizer
from app.log.logger import get_openai_logger
from app.service.client.api_client import GeminiApiClient
from app.service.image.image_create_service import ImageCreateService
from app.service.key.key_manager import KeyManager

logger = get_openai_logger()

# Shared by the streaming and non-streaming paths so both parse the upstream
# status code out of an exception message the same way.
_STATUS_CODE_PATTERN = re.compile(r"status code (\d+)")

# Model-name suffixes that select image output.
_IMAGE_MODEL_SUFFIXES = ("-image", "-image-generation")


def _parse_status_code(error_message: str) -> Optional[int]:
    """Return the integer following "status code " in *error_message*, or None."""
    match = _STATUS_CODE_PATTERN.search(error_message)
    return int(match.group(1)) if match else None


def _has_media_parts(contents: List[Dict[str, Any]]) -> bool:
    """Return True if any message has a part carrying an ``inline_data`` key."""
    for content in contents:
        if not content or "parts" not in content:
            continue
        parts = content["parts"]
        if not isinstance(parts, list):
            continue
        if any(isinstance(part, dict) and "inline_data" in part for part in parts):
            return True
    return False


def _build_tools(
    request: ChatRequest, messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Build the ``tools`` list for the request payload.

    Args:
        request: Incoming chat request; ``model`` and ``tools`` are read.
        messages: Converted message contents, inspected for media parts.

    Returns:
        A single-element list holding the tool dict, or an empty list when
        no tool applies.
    """
    tool: Dict[str, Any] = {}
    model = request.model
    has_media = _has_media_parts(messages)

    code_execution_excluded = (
        model.endswith(("-search",) + _IMAGE_MODEL_SUFFIXES) or "-thinking" in model
    )
    if (
        settings.TOOLS_CODE_EXECUTION_ENABLED
        and not code_execution_excluded
        and not has_media
    ):
        tool["codeExecution"] = {}
        logger.debug("Code execution tool enabled.")
    elif has_media:
        logger.debug("Code execution tool disabled due to media parts presence.")

    if model.endswith("-search"):
        tool["googleSearch"] = {}

    # Merge the tools supplied on the request into the tool dict.
    if request.tools:
        function_declarations = []
        for item in request.tools:
            if not item or not isinstance(item, dict):
                continue
            if item.get("type", "") == "function" and item.get("function"):
                function = deepcopy(item.get("function"))
                # "parameters" may be present but None; treat that like absent
                # instead of failing on None.get(...).
                parameters = function.get("parameters") or {}
                if parameters.get("type") == "object" and not parameters.get(
                    "properties", {}
                ):
                    function.pop("parameters", None)
                function_declarations.append(function)

        if function_declarations:
            # De-duplicate declarations by function name.
            names, functions = set(), []
            for fc in function_declarations:
                name = fc.get("name")
                if name in names:
                    continue
                if name == "googleSearch":
                    # A declared function named "googleSearch" turns on the
                    # built-in search tool instead of being declared.
                    tool["googleSearch"] = {}
                else:
                    names.add(name)
                    functions.append(fc)
            # Only attach the key when there is something to declare, so an
            # empty list is never placed in the payload.
            if functions:
                tool["functionDeclarations"] = functions

    # Works around "Tool use with function calling is unsupported":
    # function declarations are not combined with the built-in tools.
    if tool.get("functionDeclarations"):
        tool.pop("googleSearch", None)
        tool.pop("codeExecution", None)

    return [tool] if tool else []


def _get_safety_settings(model: str) -> List[Dict[str, str]]:
    """Return the safety settings to send for *model*.

    ``gemini-2.0-flash-exp`` gets its dedicated constant; every other model
    uses ``settings.SAFETY_SETTINGS``.
    """
    if model == "gemini-2.0-flash-exp":
        return GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
    return settings.SAFETY_SETTINGS


def _build_payload(
    request: ChatRequest,
    messages: List[Dict[str, Any]],
    instruction: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Build the request payload from an OpenAI-style chat request.

    Args:
        request: Incoming chat request supplying the generation parameters.
        messages: Converted message contents.
        instruction: Optional system instruction; attached only when it is a
            dict with role "system", has parts, and the model is not an
            image model.

    Returns:
        The payload dict with contents, generationConfig, tools and
        safetySettings (plus systemInstruction when applicable).
    """
    model = request.model
    is_image_model = model.endswith(_IMAGE_MODEL_SUFFIXES)

    generation_config: Dict[str, Any] = {
        "temperature": request.temperature,
        "stopSequences": request.stop,
        "topP": request.top_p,
        "topK": request.top_k,
    }
    if request.max_tokens is not None:
        generation_config["maxOutputTokens"] = request.max_tokens
    if is_image_model:
        generation_config["responseModalities"] = ["Text", "Image"]
    if model.endswith("-non-thinking"):
        generation_config["thinkingConfig"] = {"thinkingBudget": 0}
    # A configured budget takes precedence over the "-non-thinking" default.
    if model in settings.THINKING_BUDGET_MAP:
        generation_config["thinkingConfig"] = {
            "thinkingBudget": settings.THINKING_BUDGET_MAP.get(model, 1000)
        }

    payload: Dict[str, Any] = {
        "contents": messages,
        "generationConfig": generation_config,
        "tools": _build_tools(request, messages),
        "safetySettings": _get_safety_settings(model),
    }

    if (
        instruction
        and isinstance(instruction, dict)
        and instruction.get("role") == "system"
        and instruction.get("parts")
        and not is_image_model
    ):
        payload["systemInstruction"] = instruction

    return payload


class OpenAIChatService:
    """Chat service that serves OpenAI-style chat requests via GeminiApiClient."""

    def __init__(self, base_url: str, key_manager: Optional[KeyManager] = None):
        """Create the service.

        Args:
            base_url: Base URL handed to ``GeminiApiClient``.
            key_manager: Optional key manager consulted for a replacement key
                when a streaming attempt fails.
        """
        self.message_converter = OpenAIMessageConverter()
        self.response_handler = OpenAIResponseHandler(config=None)
        self.api_client = GeminiApiClient(base_url, settings.TIME_OUT)
        self.key_manager = key_manager
        self.image_create_service = ImageCreateService()

    def _extract_text_from_openai_chunk(self, chunk: Dict[str, Any]) -> str:
        """Return the delta text of the first choice in *chunk*, or ""."""
        if not chunk.get("choices"):
            return ""

        choice = chunk["choices"][0]
        if "delta" in choice and "content" in choice["delta"]:
            # "content" may be present but None; honour the declared str type.
            return choice["delta"]["content"] or ""
        return ""

    def _create_char_openai_chunk(
        self, original_chunk: Dict[str, Any], text: str
    ) -> Dict[str, Any]:
        """Return a copy of *original_chunk* whose first delta content is *text*."""
        # A JSON round-trip copies the chunk so the original is not mutated.
        chunk_copy = json.loads(json.dumps(original_chunk))
        if chunk_copy.get("choices") and "delta" in chunk_copy["choices"][0]:
            chunk_copy["choices"][0]["delta"]["content"] = text
        return chunk_copy

    async def create_chat_completion(
        self,
        request: ChatRequest,
        api_key: str,
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Create a chat completion.

        Returns:
            An async generator of SSE ``data:`` strings when
            ``request.stream`` is truthy, otherwise the response dict.
        """
        messages, instruction = self.message_converter.convert(request.messages)
        payload = _build_payload(request, messages, instruction)

        if request.stream:
            return self._handle_stream_completion(request.model, payload, api_key)
        return await self._handle_normal_completion(request.model, payload, api_key)

    async def _handle_normal_completion(
        self, model: str, payload: Dict[str, Any], api_key: str
    ) -> Dict[str, Any]:
        """Run a non-streaming completion and record request/error logs.

        Raises:
            Exception: Whatever the API client or response handler raised,
                re-raised after an error log entry is written.
        """
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None

        try:
            response = await self.api_client.generate_content(payload, model, api_key)
            usage_metadata = response.get("usageMetadata", {})
            is_success = True
            status_code = 200
            return self.response_handler.handle_response(
                response,
                model,
                stream=False,
                finish_reason="stop",
                usage_metadata=usage_metadata,
            )
        except Exception as e:
            is_success = False
            error_log_msg = str(e)
            logger.error(f"Normal API call failed with error: {error_log_msg}")
            parsed_code = _parse_status_code(error_log_msg)
            status_code = parsed_code if parsed_code is not None else 500

            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type="openai-chat-non-stream",
                error_log=error_log_msg,
                error_code=status_code,
                request_msg=payload,
            )
            # Bare raise keeps the original traceback intact.
            raise
        finally:
            end_time = time.perf_counter()
            latency_ms = int((end_time - start_time) * 1000)
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime,
            )

    async def _fake_stream_logic_impl(
        self, model: str, payload: Dict[str, Any], api_key: str
    ) -> AsyncGenerator[str, None]:
        """Core logic of the fake stream.

        Calls the non-streaming endpoint in a background task and, while it
        is pending, yields an empty chunk every
        ``settings.FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS`` seconds. Once
        the call finishes, yields the whole response as one chunk.

        Raises:
            Exception: Whatever ``generate_content`` raised.
        """
        logger.info(
            f"Fake streaming enabled for model: {model}. Calling non-streaming endpoint."
        )

        api_response_task = asyncio.create_task(
            self.api_client.generate_content(payload, model, api_key)
        )

        try:
            while True:
                # asyncio.wait does not cancel the task when the timeout
                # elapses, so the heartbeat can repeat for as long as the
                # upstream call takes. Wrapping an async generator's
                # __anext__ in wait_for would instead cancel and finalize the
                # generator on the first timeout.
                done, _ = await asyncio.wait(
                    {api_response_task},
                    timeout=settings.FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS,
                )
                if done:
                    break
                empty_chunk = self.response_handler.handle_response(
                    {}, model, stream=True, finish_reason='stop', usage_metadata=None
                )
                yield f"data: {json.dumps(empty_chunk)}\n\n"
                logger.debug("Sent empty data chunk for fake stream heartbeat.")
            # Re-raises the upstream exception, if any, for the caller's retry.
            response = api_response_task.result()
        finally:
            # If this generator is closed or cancelled early, do not leave
            # the upstream call running in the background.
            if not api_response_task.done():
                api_response_task.cancel()

        if response and response.get("candidates"):
            openai_response = self.response_handler.handle_response(
                response,
                model,
                stream=True,
                finish_reason='stop',
                usage_metadata=response.get("usageMetadata", {}),
            )
            yield f"data: {json.dumps(openai_response)}\n\n"
            logger.info(f"Sent full response content for fake stream: {model}")
        else:
            logger.error(
                f"No candidates or error in response for fake stream model {model}: {response}"
            )
            error_chunk = self.response_handler.handle_response(
                {}, model, stream=True, finish_reason='stop', usage_metadata=None
            )
            yield f"data: {json.dumps(error_chunk)}\n\n"

    async def _real_stream_logic_impl(
        self, model: str, payload: Dict[str, Any], api_key: str
    ) -> AsyncGenerator[str, None]:
        """Core logic of the real stream.

        Converts each ``data:`` line from the upstream stream into an OpenAI
        chunk and yields it, then yields a final chunk whose finish reason is
        "tool_calls" if any yielded chunk carried tool calls, else "stop".
        """
        tool_call_flag = False
        usage_metadata = None
        async for line in self.api_client.stream_generate_content(
            payload, model, api_key
        ):
            if not line.startswith("data:"):
                continue

            # Strip the prefix, then surrounding whitespace, so both
            # "data: {...}" and "data:{...}" are parsed correctly.
            chunk_str = line[len("data:"):].strip()
            if not chunk_str:
                logger.debug(
                    f"Received empty data line for model {model}, skipping."
                )
                continue
            try:
                chunk = json.loads(chunk_str)
                usage_metadata = chunk.get("usageMetadata", {})
            except json.JSONDecodeError:
                logger.error(
                    f"Failed to decode JSON from stream for model {model}: {chunk_str}"
                )
                continue

            openai_chunk = self.response_handler.handle_response(
                chunk,
                model,
                stream=True,
                finish_reason=None,
                usage_metadata=usage_metadata,
            )
            if not openai_chunk:
                continue

            text = self._extract_text_from_openai_chunk(openai_chunk)
            if text and settings.STREAM_OPTIMIZER_ENABLED:
                async for optimized_chunk_data in openai_optimizer.optimize_stream_output(
                    text,
                    lambda t: self._create_char_openai_chunk(openai_chunk, t),
                    lambda c: f"data: {json.dumps(c)}\n\n",
                ):
                    yield optimized_chunk_data
            else:
                if openai_chunk.get("choices") and openai_chunk["choices"][0].get(
                    "delta", {}
                ).get("tool_calls"):
                    tool_call_flag = True

                yield f"data: {json.dumps(openai_chunk)}\n\n"

        finish_reason = 'tool_calls' if tool_call_flag else 'stop'
        final_chunk = self.response_handler.handle_response(
            {},
            model,
            stream=True,
            finish_reason=finish_reason,
            usage_metadata=usage_metadata,
        )
        yield f"data: {json.dumps(final_chunk)}\n\n"

    async def _handle_stream_completion(
        self, model: str, payload: Dict[str, Any], api_key: str
    ) -> AsyncGenerator[str, None]:
        """Run a streaming completion with retries and fake-stream support.

        Makes up to ``settings.MAX_RETRIES`` attempts. After a failed attempt
        the key manager, if present, is asked for a replacement key; without
        a key manager, or when it returns no key, retrying stops. Each
        attempt writes a request log entry; each failure also writes an error
        log entry. If no attempt succeeds, an error event followed by
        ``[DONE]`` is yielded.
        """
        retries = 0
        max_retries = settings.MAX_RETRIES
        is_success = False
        status_code = None
        final_api_key = api_key

        while retries < max_retries:
            start_time = time.perf_counter()
            request_datetime = datetime.datetime.now()
            current_attempt_key = final_api_key

            try:
                if settings.FAKE_STREAM_ENABLED:
                    logger.info(
                        f"Using fake stream logic for model: {model}, Attempt: {retries + 1}"
                    )
                    stream_generator = self._fake_stream_logic_impl(
                        model, payload, current_attempt_key
                    )
                else:
                    logger.info(
                        f"Using real stream logic for model: {model}, Attempt: {retries + 1}"
                    )
                    stream_generator = self._real_stream_logic_impl(
                        model, payload, current_attempt_key
                    )

                async for chunk_data in stream_generator:
                    yield chunk_data

                yield "data: [DONE]\n\n"
                logger.info(
                    f"Streaming completed successfully for model: {model}, FakeStream: {settings.FAKE_STREAM_ENABLED}, Attempt: {retries + 1}"
                )
                is_success = True
                status_code = 200
                break

            except Exception as e:
                retries += 1
                is_success = False
                error_log_msg = str(e)
                logger.warning(
                    f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries} with key {current_attempt_key}"
                )

                # Uses the same parser as the non-streaming path; the earlier
                # pattern doubled the backslash inside a raw string and so
                # never matched a real status code.
                parsed_code = _parse_status_code(error_log_msg)
                if parsed_code is not None:
                    status_code = parsed_code
                elif isinstance(e, asyncio.TimeoutError):
                    status_code = 408
                else:
                    status_code = 500

                await add_error_log(
                    gemini_key=current_attempt_key,
                    model_name=model,
                    error_type="openai-chat-stream",
                    error_log=error_log_msg,
                    error_code=status_code,
                    request_msg=payload,
                )

                if self.key_manager:
                    new_api_key = await self.key_manager.handle_api_failure(
                        current_attempt_key, retries
                    )
                    if new_api_key and new_api_key != current_attempt_key:
                        final_api_key = new_api_key
                        logger.info(
                            f"Switched to new API key for next attempt: {final_api_key}"
                        )
                    elif not new_api_key:
                        logger.error(
                            f"No valid API key available after {retries} retries, ceasing attempts for this request."
                        )
                        break
                else:
                    logger.error(
                        "KeyManager not available, cannot switch API key. Ceasing attempts for this request."
                    )
                    break

                if retries >= max_retries:
                    logger.error(
                        f"Max retries ({max_retries}) reached for streaming model {model}."
                    )
            finally:
                end_time = time.perf_counter()
                latency_ms = int((end_time - start_time) * 1000)
                await add_request_log(
                    model_name=model,
                    api_key=current_attempt_key,
                    is_success=is_success,
                    status_code=status_code,
                    latency_ms=latency_ms,
                    request_time=request_datetime,
                )

        if not is_success:
            logger.error(
                f"Streaming failed permanently for model {model} after {retries} attempts."
            )
            yield f"data: {json.dumps({'error': f'Streaming failed after {retries} retries.'})}\n\n"
            yield "data: [DONE]\n\n"

    async def create_image_chat_completion(
        self, request: ChatRequest, api_key: str
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Generate images from the last message's content.

        Returns:
            An async generator of SSE ``data:`` strings when
            ``request.stream`` is truthy, otherwise the response dict.
        """
        image_generate_request = ImageGenerationRequest()
        image_generate_request.prompt = request.messages[-1]["content"]
        image_res = self.image_create_service.generate_images_chat(
            image_generate_request
        )

        if request.stream:
            return self._handle_stream_image_completion(
                request.model, image_res, api_key
            )
        else:
            return await self._handle_normal_image_completion(
                request.model, image_res, api_key
            )

    async def _handle_stream_image_completion(
        self, model: str, image_data: str, api_key: str
    ) -> AsyncGenerator[str, None]:
        """Stream an image chat response as SSE ``data:`` strings.

        On failure an error event followed by ``[DONE]`` is yielded instead
        of raising. A request log entry is written in every case.
        """
        logger.info(f"Starting stream image completion for model: {model}")
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None

        try:
            if image_data:
                openai_chunk = self.response_handler.handle_image_chat_response(
                    image_data, model, stream=True, finish_reason=None
                )
                if openai_chunk:
                    # Extract the text content.
                    text = self._extract_text_from_openai_chunk(openai_chunk)
                    if text:
                        # Pass the text through the stream output optimizer.
                        async for optimized_chunk in openai_optimizer.optimize_stream_output(
                            text,
                            lambda t: self._create_char_openai_chunk(openai_chunk, t),
                            lambda c: f"data: {json.dumps(c)}\n\n",
                        ):
                            yield optimized_chunk
                    else:
                        # No text content (e.g. an image URL): emit the chunk whole.
                        yield f"data: {json.dumps(openai_chunk)}\n\n"
            yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
            logger.info(
                f"Stream image completion finished successfully for model: {model}"
            )
            is_success = True
            status_code = 200
            yield "data: [DONE]\n\n"
        except Exception as e:
            is_success = False
            error_log_msg = f"Stream image completion failed for model {model}: {e}"
            logger.error(error_log_msg)
            status_code = 500
            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type="openai-image-stream",
                error_log=error_log_msg,
                error_code=status_code,
                # image_data may be None; slicing None here would raise and
                # mask the original error.
                request_msg={"image_data_truncated": (image_data or "")[:1000]},
            )
            yield f"data: {json.dumps({'error': error_log_msg})}\n\n"
            yield "data: [DONE]\n\n"
        finally:
            end_time = time.perf_counter()
            latency_ms = int((end_time - start_time) * 1000)
            logger.info(
                f"Stream image completion for model {model} took {latency_ms} ms. Success: {is_success}"
            )
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime,
            )

    async def _handle_normal_image_completion(
        self, model: str, image_data: str, api_key: str
    ) -> Dict[str, Any]:
        """Return a non-streaming image chat response and record logs.

        Raises:
            Exception: Whatever the response handler raised, re-raised after
                an error log entry is written.
        """
        logger.info(f"Starting normal image completion for model: {model}")
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None

        try:
            result = self.response_handler.handle_image_chat_response(
                image_data, model, stream=False, finish_reason="stop"
            )
            logger.info(
                f"Normal image completion finished successfully for model: {model}"
            )
            is_success = True
            status_code = 200
            return result
        except Exception as e:
            is_success = False
            error_log_msg = f"Normal image completion failed for model {model}: {e}"
            logger.error(error_log_msg)
            status_code = 500
            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type="openai-image-non-stream",
                error_log=error_log_msg,
                error_code=status_code,
                # image_data may be None; slicing None here would raise and
                # mask the original error.
                request_msg={"image_data_truncated": (image_data or "")[:1000]},
            )
            # Bare raise keeps the original traceback intact.
            raise
        finally:
            end_time = time.perf_counter()
            latency_ms = int((end_time - start_time) * 1000)
            logger.info(
                f"Normal image completion for model {model} took {latency_ms} ms. Success: {is_success}"
            )
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime,
            )