GeminiBalance / app /router /gemini_routes.py
CatPtain's picture
Upload 77 files
76b9762 verified
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from copy import deepcopy
import asyncio
from app.config.config import settings
from app.log.logger import get_gemini_logger
from app.core.security import SecurityService
from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest
from app.service.chat.gemini_chat_service import GeminiChatService
from app.service.key.key_manager import KeyManager, get_key_manager_instance
from app.service.model.model_service import ModelService
from app.handler.retry_handler import RetryHandler
from app.handler.error_handler import handle_route_errors
from app.core.constants import API_VERSION
router = APIRouter(prefix=f"/gemini/{API_VERSION}")
router_v1beta = APIRouter(prefix=f"/{API_VERSION}")
logger = get_gemini_logger()
security_service = SecurityService()
model_service = ModelService()
async def get_key_manager():
"""获取密钥管理器实例"""
return await get_key_manager_instance()
async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager)):
"""获取下一个可用的API密钥"""
return await key_manager.get_next_working_key()
async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
"""获取Gemini聊天服务实例"""
return GeminiChatService(settings.BASE_URL, key_manager)
@router.get("/models")
@router_v1beta.get("/models")
async def list_models(
_=Depends(security_service.verify_key_or_goog_api_key),
key_manager: KeyManager = Depends(get_key_manager)
):
"""获取可用的 Gemini 模型列表,并根据配置添加衍生模型(搜索、图像、非思考)。"""
operation_name = "list_gemini_models"
logger.info("-" * 50 + operation_name + "-" * 50)
logger.info("Handling Gemini models list request")
try:
api_key = await key_manager.get_first_valid_key()
if not api_key:
raise HTTPException(status_code=503, detail="No valid API keys available to fetch models.")
logger.info(f"Using API key: {api_key}")
models_data = await model_service.get_gemini_models(api_key)
if not models_data or "models" not in models_data:
raise HTTPException(status_code=500, detail="Failed to fetch base models list.")
models_json = deepcopy(models_data)
model_mapping = {x.get("name", "").split("/", maxsplit=1)[-1]: x for x in models_json.get("models", [])}
def add_derived_model(base_name, suffix, display_suffix):
model = model_mapping.get(base_name)
if not model:
logger.warning(f"Base model '{base_name}' not found for derived model '{suffix}'.")
return
item = deepcopy(model)
item["name"] = f"models/{base_name}{suffix}"
display_name = f'{item.get("displayName", base_name)}{display_suffix}'
item["displayName"] = display_name
item["description"] = display_name
models_json["models"].append(item)
if settings.SEARCH_MODELS:
for name in settings.SEARCH_MODELS:
add_derived_model(name, "-search", " For Search")
if settings.IMAGE_MODELS:
for name in settings.IMAGE_MODELS:
add_derived_model(name, "-image", " For Image")
if settings.THINKING_MODELS:
for name in settings.THINKING_MODELS:
add_derived_model(name, "-non-thinking", " Non Thinking")
logger.info("Gemini models list request successful")
return models_json
except HTTPException as http_exc:
raise http_exc
except Exception as e:
logger.error(f"Error getting Gemini models list: {str(e)}")
raise HTTPException(
status_code=500, detail="Internal server error while fetching Gemini models list"
) from e
@router.post("/models/{model_name}:generateContent")
@router_v1beta.post("/models/{model_name}:generateContent")
@RetryHandler(key_arg="api_key")
async def generate_content(
model_name: str,
request: GeminiRequest,
_=Depends(security_service.verify_key_or_goog_api_key),
api_key: str = Depends(get_next_working_key),
key_manager: KeyManager = Depends(get_key_manager),
chat_service: GeminiChatService = Depends(get_chat_service)
):
"""处理 Gemini 非流式内容生成请求。"""
operation_name = "gemini_generate_content"
async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"):
logger.info(f"Handling Gemini content generation request for model: {model_name}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
if not await model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
response = await chat_service.generate_content(
model=model_name,
request=request,
api_key=api_key
)
return response
@router.post("/models/{model_name}:streamGenerateContent")
@router_v1beta.post("/models/{model_name}:streamGenerateContent")
@RetryHandler(key_arg="api_key")
async def stream_generate_content(
model_name: str,
request: GeminiRequest,
_=Depends(security_service.verify_key_or_goog_api_key),
api_key: str = Depends(get_next_working_key),
key_manager: KeyManager = Depends(get_key_manager),
chat_service: GeminiChatService = Depends(get_chat_service)
):
"""处理 Gemini 流式内容生成请求。"""
operation_name = "gemini_stream_generate_content"
async with handle_route_errors(logger, operation_name, failure_message="Streaming request initiation failed"):
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
if not await model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
response_stream = chat_service.stream_generate_content(
model=model_name,
request=request,
api_key=api_key
)
return StreamingResponse(response_stream, media_type="text/event-stream")
@router.post("/reset-all-fail-counts")
async def reset_all_key_fail_counts(key_type: str = None, key_manager: KeyManager = Depends(get_key_manager)):
"""批量重置Gemini API密钥的失败计数,可选择性地仅重置有效或无效密钥"""
logger.info("-" * 50 + "reset_all_gemini_key_fail_counts" + "-" * 50)
logger.info(f"Received reset request with key_type: {key_type}")
try:
# 获取分类后的密钥
keys_by_status = await key_manager.get_keys_by_status()
valid_keys = keys_by_status.get("valid_keys", {})
invalid_keys = keys_by_status.get("invalid_keys", {})
# 根据类型选择要重置的密钥
keys_to_reset = []
if key_type == "valid":
keys_to_reset = list(valid_keys.keys())
logger.info(f"Resetting only valid keys, count: {len(keys_to_reset)}")
elif key_type == "invalid":
keys_to_reset = list(invalid_keys.keys())
logger.info(f"Resetting only invalid keys, count: {len(keys_to_reset)}")
else:
# 重置所有密钥
await key_manager.reset_failure_counts()
return JSONResponse({"success": True, "message": "所有密钥的失败计数已重置"})
# 批量重置指定类型的密钥
for key in keys_to_reset:
await key_manager.reset_key_failure_count(key)
return JSONResponse({
"success": True,
"message": f"{key_type}密钥的失败计数已重置",
"reset_count": len(keys_to_reset)
})
except Exception as e:
logger.error(f"Failed to reset key failure counts: {str(e)}")
return JSONResponse({"success": False, "message": f"批量重置失败: {str(e)}"}, status_code=500)
@router.post("/reset-selected-fail-counts")
async def reset_selected_key_fail_counts(
request: ResetSelectedKeysRequest,
key_manager: KeyManager = Depends(get_key_manager)
):
"""批量重置选定Gemini API密钥的失败计数"""
logger.info("-" * 50 + "reset_selected_gemini_key_fail_counts" + "-" * 50)
keys_to_reset = request.keys
key_type = request.key_type
logger.info(f"Received reset request for {len(keys_to_reset)} selected {key_type} keys.")
if not keys_to_reset:
return JSONResponse({"success": False, "message": "没有提供需要重置的密钥"}, status_code=400)
reset_count = 0
errors = []
try:
for key in keys_to_reset:
try:
result = await key_manager.reset_key_failure_count(key)
if result:
reset_count += 1
else:
logger.warning(f"Key not found during selective reset: {key}")
except Exception as key_error:
logger.error(f"Error resetting key {key}: {str(key_error)}")
errors.append(f"Key {key}: {str(key_error)}")
if errors:
error_message = f"批量重置完成,但出现错误: {'; '.join(errors)}"
final_success = reset_count > 0
status_code = 207 if final_success and errors else 500
return JSONResponse({
"success": final_success,
"message": error_message,
"reset_count": reset_count
}, status_code=status_code)
return JSONResponse({
"success": True,
"message": f"成功重置 {reset_count} 个选定 {key_type} 密钥的失败计数",
"reset_count": reset_count
})
except Exception as e:
logger.error(f"Failed to process reset selected key failure counts request: {str(e)}")
return JSONResponse({"success": False, "message": f"批量重置处理失败: {str(e)}"}, status_code=500)
@router.post("/reset-fail-count/{api_key}")
async def reset_key_fail_count(api_key: str, key_manager: KeyManager = Depends(get_key_manager)):
"""重置指定Gemini API密钥的失败计数"""
logger.info("-" * 50 + "reset_gemini_key_fail_count" + "-" * 50)
logger.info(f"Resetting failure count for API key: {api_key}")
try:
result = await key_manager.reset_key_failure_count(api_key)
if result:
return JSONResponse({"success": True, "message": "失败计数已重置"})
return JSONResponse({"success": False, "message": "未找到指定密钥"}, status_code=404)
except Exception as e:
logger.error(f"Failed to reset key failure count: {str(e)}")
return JSONResponse({"success": False, "message": f"重置失败: {str(e)}"}, status_code=500)
@router.post("/verify-key/{api_key}")
async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get_chat_service), key_manager: KeyManager = Depends(get_key_manager)):
"""验证Gemini API密钥的有效性"""
logger.info("-" * 50 + "verify_gemini_key" + "-" * 50)
logger.info("Verifying API key validity")
try:
gemini_request = GeminiRequest(
contents=[
GeminiContent(
role="user",
parts=[{"text": "hi"}],
)
],
generation_config={"temperature": 0.7, "top_p": 1.0, "max_output_tokens": 10}
)
response = await chat_service.generate_content(
settings.TEST_MODEL,
gemini_request,
api_key
)
if response:
return JSONResponse({"status": "valid"})
except Exception as e:
logger.error(f"Key verification failed: {str(e)}")
async with key_manager.failure_count_lock:
if api_key in key_manager.key_failure_counts:
key_manager.key_failure_counts[api_key] += 1
logger.warning(f"Verification exception for key: {api_key}, incrementing failure count")
return JSONResponse({"status": "invalid", "error": str(e)})
@router.post("/verify-selected-keys")
async def verify_selected_keys(
request: VerifySelectedKeysRequest,
chat_service: GeminiChatService = Depends(get_chat_service),
key_manager: KeyManager = Depends(get_key_manager)
):
"""批量验证选定Gemini API密钥的有效性"""
logger.info("-" * 50 + "verify_selected_gemini_keys" + "-" * 50)
keys_to_verify = request.keys
logger.info(f"Received verification request for {len(keys_to_verify)} selected keys.")
if not keys_to_verify:
return JSONResponse({"success": False, "message": "没有提供需要验证的密钥"}, status_code=400)
successful_keys = []
failed_keys = {}
async def _verify_single_key(api_key: str):
"""内部函数,用于验证单个密钥并处理异常"""
nonlocal successful_keys, failed_keys
try:
gemini_request = GeminiRequest(
contents=[GeminiContent(role="user", parts=[{"text": "hi"}])],
generation_config={"temperature": 0.7, "top_p": 1.0, "max_output_tokens": 10}
)
await chat_service.generate_content(
settings.TEST_MODEL,
gemini_request,
api_key
)
successful_keys.append(api_key)
return api_key, "valid", None
except Exception as e:
error_message = str(e)
logger.warning(f"Key verification failed for {api_key}: {error_message}")
async with key_manager.failure_count_lock:
if api_key in key_manager.key_failure_counts:
key_manager.key_failure_counts[api_key] += 1
logger.warning(f"Bulk verification exception for key: {api_key}, incrementing failure count")
else:
key_manager.key_failure_counts[api_key] = 1
logger.warning(f"Bulk verification exception for key: {api_key}, initializing failure count to 1")
failed_keys[api_key] = error_message
return api_key, "invalid", error_message
tasks = [_verify_single_key(key) for key in keys_to_verify]
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
logger.error(f"An unexpected error occurred during bulk verification task: {result}")
elif result:
if not isinstance(result, Exception) and result:
key, status, error = result
elif isinstance(result, Exception):
logger.error(f"Task execution error during bulk verification: {result}")
valid_count = len(successful_keys)
invalid_count = len(failed_keys)
logger.info(f"Bulk verification finished. Valid: {valid_count}, Invalid: {invalid_count}")
if failed_keys:
message = f"批量验证完成。成功: {valid_count}, 失败: {invalid_count}。"
return JSONResponse({
"success": True,
"message": message,
"successful_keys": successful_keys,
"failed_keys": failed_keys,
"valid_count": valid_count,
"invalid_count": invalid_count
})
else:
message = f"批量验证成功完成。所有 {valid_count} 个密钥均有效。"
return JSONResponse({
"success": True,
"message": message,
"successful_keys": successful_keys,
"failed_keys": {},
"valid_count": valid_count,
"invalid_count": 0
})