from fastapi import APIRouter, Depends, HTTPException, Response
from fastapi.responses import StreamingResponse

from app.config.config import settings
from app.core.security import SecurityService
from app.domain.openai_models import (
    ChatRequest,
    EmbeddingRequest,
    ImageGenerationRequest,
    TTSRequest,
)
from app.handler.retry_handler import RetryHandler
from app.handler.error_handler import handle_route_errors
from app.log.logger import get_openai_logger
from app.service.chat.openai_chat_service import OpenAIChatService
from app.service.embedding.embedding_service import EmbeddingService
from app.service.image.image_create_service import ImageCreateService
from app.service.tts.tts_service import TTSService
from app.service.key.key_manager import KeyManager, get_key_manager_instance
from app.service.model.model_service import ModelService

router = APIRouter()
logger = get_openai_logger()

security_service = SecurityService()
model_service = ModelService()
embedding_service = EmbeddingService()
image_create_service = ImageCreateService()
tts_service = TTSService()

async def get_key_manager():
    """Return the shared KeyManager instance."""
    return await get_key_manager_instance()


async def get_next_working_key_wrapper(
    key_manager: KeyManager = Depends(get_key_manager),
):
    """Dependency that yields the next working API key from the key manager."""
    return await key_manager.get_next_working_key()


async def get_openai_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
    """Return an OpenAIChatService instance bound to the configured base URL."""
    return OpenAIChatService(settings.BASE_URL, key_manager)


async def get_tts_service():
    """Return the shared TTSService instance."""
    return tts_service
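
# Note on the dependency chain: FastAPI resolves Depends() per request, so each
# handler below receives the shared KeyManager via get_key_manager, and handlers
# that declare an api_key parameter get the next working key selected by
# get_next_working_key_wrapper.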

async def list_models(
    _=Depends(security_service.verify_authorization),
    key_manager: KeyManager = Depends(get_key_manager),
):
    """List the available OpenAI models (compatible with both Gemini and OpenAI)."""
    operation_name = "list_models"
    async with handle_route_errors(logger, operation_name):
        logger.info("Handling models list request")
        api_key = await key_manager.get_first_valid_key()
        logger.info(f"Using API key: {api_key}")
        return await model_service.get_gemini_openai_models(api_key)

async def chat_completion(
    request: ChatRequest,
    _=Depends(security_service.verify_authorization),
    api_key: str = Depends(get_next_working_key_wrapper),
    key_manager: KeyManager = Depends(get_key_manager),
    chat_service: OpenAIChatService = Depends(get_openai_chat_service),
):
    """Handle OpenAI chat completion requests, with streaming support and switching to the image-chat model."""
    operation_name = "chat_completion"
    is_image_chat = request.model == f"{settings.CREATE_IMAGE_MODEL}-chat"
    current_api_key = api_key
    if is_image_chat:
        current_api_key = await key_manager.get_paid_key()

    async with handle_route_errors(logger, operation_name):
        logger.info(f"Handling chat completion request for model: {request.model}")
        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
        logger.info(f"Using API key: {current_api_key}")
        if not await model_service.check_model_support(request.model):
            raise HTTPException(
                status_code=400, detail=f"Model {request.model} is not supported"
            )

        if is_image_chat:
            response = await chat_service.create_image_chat_completion(request, current_api_key)
        else:
            response = await chat_service.create_chat_completion(request, current_api_key)

        if request.stream:
            return StreamingResponse(response, media_type="text/event-stream")
        return response
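
# A minimal client-side sketch for the handler above, assuming this router is
# mounted so the chat handler is reachable at an OpenAI-style
# /v1/chat/completions path; the base_url, auth token, and model name are
# illustrative and depend on the actual deployment:
#
#     from openai import OpenAI
#
#     client = OpenAI(base_url="http://localhost:8000/v1", api_key="<auth token>")
#     stream = client.chat.completions.create(
#         model="<model supported by model_service>",
#         messages=[{"role": "user", "content": "Hello"}],
#         stream=True,  # served via StreamingResponse (text/event-stream)
#     )
#     for chunk in stream:
#         print(chunk.choices[0].delta.content or "", end="")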

async def generate_image(
    request: ImageGenerationRequest,
    _=Depends(security_service.verify_authorization),
):
    """Handle OpenAI image generation requests."""
    operation_name = "generate_image"
    async with handle_route_errors(logger, operation_name):
        logger.info(f"Handling image generation request for prompt: {request.prompt}")
        response = image_create_service.generate_images(request)
        return response

async def embedding(
    request: EmbeddingRequest,
    _=Depends(security_service.verify_authorization),
    key_manager: KeyManager = Depends(get_key_manager),
):
    """Handle OpenAI text embedding requests."""
    operation_name = "embedding"
    async with handle_route_errors(logger, operation_name):
        logger.info(f"Handling embedding request for model: {request.model}")
        api_key = await key_manager.get_next_working_key()
        logger.info(f"Using API key: {api_key}")
        response = await embedding_service.create_embedding(
            input_text=request.input, model=request.model, api_key=api_key
        )
        return response

async def get_keys_list(
    _=Depends(security_service.verify_auth_token),
    key_manager: KeyManager = Depends(get_key_manager),
):
    """Return the lists of valid and invalid API keys (requires the admin auth token)."""
    operation_name = "get_keys_list"
    async with handle_route_errors(logger, operation_name):
        logger.info("Handling keys list request")
        keys_status = await key_manager.get_keys_by_status()
        return {
            "status": "success",
            "data": {
                "valid_keys": keys_status["valid_keys"],
                "invalid_keys": keys_status["invalid_keys"],
            },
            "total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"]),
        }

async def text_to_speech(
    request: TTSRequest,
    _=Depends(security_service.verify_authorization),
    api_key: str = Depends(get_next_working_key_wrapper),
    tts_service: TTSService = Depends(get_tts_service),
):
    """Handle OpenAI text-to-speech (TTS) requests."""
    operation_name = "text_to_speech"
    async with handle_route_errors(logger, operation_name):
        logger.info(f"Handling TTS request for model: {request.model}")
        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
        logger.info(f"Using API key: {api_key}")
        audio_data = await tts_service.create_tts(request, api_key)
        return Response(content=audio_data, media_type="audio/wav")
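
# ---------------------------------------------------------------------------
# The route decorators for the handlers above are not shown in this file as
# captured. A minimal wiring sketch, assuming standard OpenAI-compatible paths
# (the real paths, and any RetryHandler decoration of the key-consuming routes,
# may differ in the upstream project):
#
#     router.get("/v1/models")(list_models)
#     router.post("/v1/chat/completions")(chat_completion)
#     router.post("/v1/images/generations")(generate_image)
#     router.post("/v1/embeddings")(embedding)
#     router.get("/keys/list")(get_keys_list)
#     router.post("/v1/audio/speech")(text_to_speech)
#
# The router is then mounted on the application:
#
#     from fastapi import FastAPI
#     app = FastAPI()
#     app.include_router(router)
# ---------------------------------------------------------------------------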