Spaces:
Running
Running
File size: 6,699 Bytes
76b9762 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
from fastapi import APIRouter, Depends, HTTPException, Response
from fastapi.responses import StreamingResponse
from app.config.config import settings
from app.core.security import SecurityService
from app.domain.openai_models import (
ChatRequest,
EmbeddingRequest,
ImageGenerationRequest,
TTSRequest,
)
from app.handler.retry_handler import RetryHandler
from app.handler.error_handler import handle_route_errors
from app.log.logger import get_openai_logger
from app.service.chat.openai_chat_service import OpenAIChatService
from app.service.embedding.embedding_service import EmbeddingService
from app.service.image.image_create_service import ImageCreateService
from app.service.tts.tts_service import TTSService
from app.service.key.key_manager import KeyManager, get_key_manager_instance
from app.service.model.model_service import ModelService
router = APIRouter()
# Module-level logger for all OpenAI-compatible routes.
logger = get_openai_logger()
# Shared service singletons used by the route handlers below.
security_service = SecurityService()
model_service = ModelService()
embedding_service = EmbeddingService()
image_create_service = ImageCreateService()
tts_service = TTSService()
async def get_key_manager():
    """FastAPI dependency resolving the shared KeyManager instance."""
    manager = await get_key_manager_instance()
    return manager
async def get_next_working_key_wrapper(
    key_manager: KeyManager = Depends(get_key_manager),
):
    """Dependency yielding the next working API key from the key manager."""
    next_key = await key_manager.get_next_working_key()
    return next_key
async def get_openai_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
    """Build an OpenAI chat service bound to the configured base URL."""
    service = OpenAIChatService(settings.BASE_URL, key_manager)
    return service
async def get_tts_service():
    """Dependency exposing the module-level TTS service singleton."""
    # Returns the shared instance created at import time.
    return tts_service
@router.get("/v1/models")
@router.get("/hf/v1/models")
async def list_models(
    _=Depends(security_service.verify_authorization),
    key_manager: KeyManager = Depends(get_key_manager),
):
    """List available models (Gemini- and OpenAI-compatible).

    Requires client authorization; queries the upstream catalogue with
    the first valid key held by the key manager.
    """
    operation_name = "list_models"
    async with handle_route_errors(logger, operation_name):
        logger.info("Handling models list request")
        api_key = await key_manager.get_first_valid_key()
        # Security fix: never write full credentials to logs —
        # only a short suffix is kept for correlation.
        logger.info(f"Using API key: ...{api_key[-4:] if api_key else ''}")
        return await model_service.get_gemini_openai_models(api_key)
@router.post("/v1/chat/completions")
@router.post("/hf/v1/chat/completions")
@RetryHandler(key_arg="api_key")
async def chat_completion(
    request: ChatRequest,
    _=Depends(security_service.verify_authorization),
    api_key: str = Depends(get_next_working_key_wrapper),
    key_manager: KeyManager = Depends(get_key_manager),
    chat_service: OpenAIChatService = Depends(get_openai_chat_service),
):
    """Handle an OpenAI chat-completion request.

    Supports streaming responses. When the requested model is the
    image-chat variant (``{CREATE_IMAGE_MODEL}-chat``), the request is
    switched to a paid key and routed to the image-chat service path.

    Raises:
        HTTPException: 400 if the requested model is not supported.
    """
    operation_name = "chat_completion"
    is_image_chat = request.model == f"{settings.CREATE_IMAGE_MODEL}-chat"
    current_api_key = api_key
    if is_image_chat:
        # Image-chat generation requires a paid-tier key.
        current_api_key = await key_manager.get_paid_key()
    async with handle_route_errors(logger, operation_name):
        logger.info(f"Handling chat completion request for model: {request.model}")
        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
        # Security fix: never write full credentials to logs —
        # only a short suffix is kept for correlation.
        logger.info(
            f"Using API key: ...{current_api_key[-4:] if current_api_key else ''}"
        )
        if not await model_service.check_model_support(request.model):
            raise HTTPException(
                status_code=400, detail=f"Model {request.model} is not supported"
            )
        if is_image_chat:
            response = await chat_service.create_image_chat_completion(
                request, current_api_key
            )
        else:
            response = await chat_service.create_chat_completion(
                request, current_api_key
            )
        # Both paths share identical stream handling — deduplicated here.
        if request.stream:
            return StreamingResponse(response, media_type="text/event-stream")
        return response
@router.post("/v1/images/generations")
@router.post("/hf/v1/images/generations")
async def generate_image(
    request: ImageGenerationRequest,
    _=Depends(security_service.verify_authorization),
):
    """Handle an OpenAI-style image generation request."""
    async with handle_route_errors(logger, "generate_image"):
        logger.info(f"Handling image generation request for prompt: {request.prompt}")
        # NOTE(review): generate_images is invoked without await here —
        # presumably a synchronous service method; confirm against its definition.
        return image_create_service.generate_images(request)
@router.post("/v1/embeddings")
@router.post("/hf/v1/embeddings")
async def embedding(
    request: EmbeddingRequest,
    _=Depends(security_service.verify_authorization),
    key_manager: KeyManager = Depends(get_key_manager),
):
    """Handle an OpenAI text-embedding request.

    Picks the next working key from the key manager and forwards the
    input text and model to the embedding service.
    """
    operation_name = "embedding"
    async with handle_route_errors(logger, operation_name):
        logger.info(f"Handling embedding request for model: {request.model}")
        api_key = await key_manager.get_next_working_key()
        # Security fix: never write full credentials to logs —
        # only a short suffix is kept for correlation.
        logger.info(f"Using API key: ...{api_key[-4:] if api_key else ''}")
        response = await embedding_service.create_embedding(
            input_text=request.input, model=request.model, api_key=api_key
        )
        return response
@router.get("/v1/keys/list")
@router.get("/hf/v1/keys/list")
async def get_keys_list(
    _=Depends(security_service.verify_auth_token),
    key_manager: KeyManager = Depends(get_key_manager),
):
    """Return the valid and invalid API key lists (admin-token protected)."""
    async with handle_route_errors(logger, "get_keys_list"):
        logger.info("Handling keys list request")
        keys_status = await key_manager.get_keys_by_status()
        valid = keys_status["valid_keys"]
        invalid = keys_status["invalid_keys"]
        return {
            "status": "success",
            "data": {
                "valid_keys": valid,
                "invalid_keys": invalid,
            },
            "total": len(valid) + len(invalid),
        }
@router.post("/v1/audio/speech")
@router.post("/hf/v1/audio/speech")
async def text_to_speech(
    request: TTSRequest,
    _=Depends(security_service.verify_authorization),
    api_key: str = Depends(get_next_working_key_wrapper),
    tts_service: TTSService = Depends(get_tts_service),
):
    """Handle an OpenAI-compatible text-to-speech request.

    Returns the synthesized audio as a WAV response body.
    """
    operation_name = "text_to_speech"
    async with handle_route_errors(logger, operation_name):
        logger.info(f"Handling TTS request for model: {request.model}")
        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
        # Security fix: never write full credentials to logs —
        # only a short suffix is kept for correlation.
        logger.info(f"Using API key: ...{api_key[-4:] if api_key else ''}")
        audio_data = await tts_service.create_tts(request, api_key)
        return Response(content=audio_data, media_type="audio/wav")
|