import datetime import json import re import time from typing import Any, AsyncGenerator, Dict, Union from app.config.config import settings from app.database.services import ( add_error_log, add_request_log, ) from app.domain.openai_models import ChatRequest, ImageGenerationRequest from app.service.client.api_client import OpenaiApiClient from app.service.key.key_manager import KeyManager from app.log.logger import get_openai_compatible_logger logger = get_openai_compatible_logger() class OpenAICompatiableService: def __init__(self, base_url: str, key_manager: KeyManager = None): self.key_manager = key_manager self.base_url = base_url self.api_client = OpenaiApiClient(base_url, settings.TIME_OUT) async def get_models(self, api_key: str) -> Dict[str, Any]: return await self.api_client.get_models(api_key) async def create_chat_completion( self, request: ChatRequest, api_key: str, ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]: """创建聊天完成""" request_dict = request.model_dump() # 移除值为null的 request_dict = {k: v for k, v in request_dict.items() if v is not None} del request_dict["top_k"] # 删除top_k参数,目前不支持该参数 if request.stream: return self._handle_stream_completion(request.model, request_dict, api_key) return await self._handle_normal_completion(request.model, request_dict, api_key) async def generate_images( self, request: ImageGenerationRequest, ) -> Dict[str, Any]: """生成图片""" request_dict = request.model_dump() # 移除值为null的 request_dict = {k: v for k, v in request_dict.items() if v is not None} api_key = settings.PAID_KEY return await self.api_client.generate_images(request_dict, api_key) async def create_embeddings( self, input_text: str, model: str, api_key: str, ) -> Dict[str, Any]: """创建嵌入""" return await self.api_client.create_embeddings(input_text, model, api_key) async def _handle_normal_completion( self, model: str, request: dict, api_key: str ) -> Dict[str, Any]: """处理普通聊天完成""" start_time = time.perf_counter() request_datetime = datetime.datetime.now() is_success = False status_code = None response = None try: response = await self.api_client.generate_content(request, api_key) is_success = True status_code = 200 return response except Exception as e: is_success = False error_log_msg = str(e) logger.error(f"Normal API call failed with error: {error_log_msg}") match = re.search(r"status code (\d+)", error_log_msg) if match: status_code = int(match.group(1)) else: status_code = 500 await add_error_log( gemini_key=api_key, model_name=model, error_type="openai-compatiable-non-stream", error_log=error_log_msg, error_code=status_code, request_msg=request, ) raise e finally: end_time = time.perf_counter() latency_ms = int((end_time - start_time) * 1000) await add_request_log( model_name=model, api_key=api_key, is_success=is_success, status_code=status_code, latency_ms=latency_ms, request_time=request_datetime, ) async def _handle_stream_completion( self, model: str, payload: dict, api_key: str ) -> AsyncGenerator[str, None]: """处理流式聊天完成,添加重试逻辑""" retries = 0 max_retries = settings.MAX_RETRIES is_success = False status_code = None final_api_key = api_key while retries < max_retries: start_time = time.perf_counter() request_datetime = datetime.datetime.now() current_attempt_key = api_key final_api_key = current_attempt_key try: async for line in self.api_client.stream_generate_content( payload, current_attempt_key ): if line.startswith("data:"): # print(line) yield line + "\n\n" logger.info("Streaming completed successfully") is_success = True status_code = 200 break except Exception as e: retries += 1 is_success = False error_log_msg = str(e) logger.warning( f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}" ) match = re.search(r"status code (\d+)", error_log_msg) if match: status_code = int(match.group(1)) else: status_code = 500 await add_error_log( gemini_key=current_attempt_key, model_name=model, error_type="openai-compatiable-stream", error_log=error_log_msg, error_code=status_code, request_msg=payload, ) if self.key_manager: api_key = await self.key_manager.handle_api_failure( current_attempt_key, retries ) if api_key: logger.info(f"Switched to new API key: {api_key}") else: logger.error( f"No valid API key available after {retries} retries." ) break else: logger.error("KeyManager not available for retry logic.") break if retries >= max_retries: logger.error(f"Max retries ({max_retries}) reached for streaming.") break finally: end_time = time.perf_counter() latency_ms = int((end_time - start_time) * 1000) await add_request_log( model_name=model, api_key=final_api_key, is_success=is_success, status_code=status_code, latency_ms=latency_ms, request_time=request_datetime, ) if not is_success and retries >= max_retries: yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n" yield "data: [DONE]\n\n"