diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..a5ba95bc30f1b7518b2acadfdac6eae62c444762 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +files/image.png filter=lfs diff=lfs merge=lfs -text +files/image1.png filter=lfs diff=lfs merge=lfs -text +files/image2.png filter=lfs diff=lfs merge=lfs -text +files/image3.png filter=lfs diff=lfs merge=lfs -text +files/image4.png filter=lfs diff=lfs merge=lfs -text +files/image5.png filter=lfs diff=lfs merge=lfs -text +files/image6.png filter=lfs diff=lfs merge=lfs -text +files/image7.png filter=lfs diff=lfs merge=lfs -text diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..7ef23a975ce8a1d03619bb1d71d6372353fe6add --- /dev/null +++ b/Dockerfile @@ -0,0 +1,23 @@ +FROM python:3.10-slim + +WORKDIR /app + +# 复制所需文件到容器中 +COPY ./requirements.txt /app +COPY ./VERSION /app + +RUN pip install --no-cache-dir -r requirements.txt +COPY ./app /app/app +ENV API_KEYS='["your_api_key_1"]' +ENV ALLOWED_TOKENS='["your_token_1"]' +ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta +ENV TOOLS_CODE_EXECUTION_ENABLED=false +ENV IMAGE_MODELS='["gemini-2.0-flash-exp"]' +ENV SEARCH_MODELS='["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]' +ENV URL_NORMALIZATION_ENABLED=false + +# Expose port +EXPOSE 7860 + +# Run the application +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860", "--no-access-log"] diff --git a/app/config/config.py b/app/config/config.py new file mode 100644 index 0000000000000000000000000000000000000000..d24055d8c927cc0c4abc50bd2463ff00a28c7da1 --- /dev/null +++ b/app/config/config.py @@ -0,0 +1,479 @@ +""" +应用程序配置模块 +""" + +import datetime +import json +from typing import Any, Dict, List, Type + +from 
pydantic import ValidationError, ValidationInfo, field_validator +from pydantic_settings import BaseSettings +from sqlalchemy import insert, select, update + +from app.core.constants import ( + API_VERSION, + DEFAULT_CREATE_IMAGE_MODEL, + DEFAULT_FILTER_MODELS, + DEFAULT_MODEL, + DEFAULT_SAFETY_SETTINGS, + DEFAULT_STREAM_CHUNK_SIZE, + DEFAULT_STREAM_LONG_TEXT_THRESHOLD, + DEFAULT_STREAM_MAX_DELAY, + DEFAULT_STREAM_MIN_DELAY, + DEFAULT_STREAM_SHORT_TEXT_THRESHOLD, + DEFAULT_TIMEOUT, + MAX_RETRIES, +) +from app.log.logger import Logger + + +class Settings(BaseSettings): + # 数据库配置 + DATABASE_TYPE: str = "mysql" # sqlite 或 mysql + SQLITE_DATABASE: str = "default_db" + MYSQL_HOST: str = "" + MYSQL_PORT: int = 3306 + MYSQL_USER: str = "" + MYSQL_PASSWORD: str = "" + MYSQL_DATABASE: str = "" + MYSQL_SOCKET: str = "" + + # 验证 MySQL 配置 + @field_validator( + "MYSQL_HOST", "MYSQL_PORT", "MYSQL_USER", "MYSQL_PASSWORD", "MYSQL_DATABASE" + ) + def validate_mysql_config(cls, v: Any, info: ValidationInfo) -> Any: + if info.data.get("DATABASE_TYPE") == "mysql": + if v is None or v == "": + raise ValueError( + "MySQL configuration is required when DATABASE_TYPE is 'mysql'" + ) + return v + + # API相关配置 + API_KEYS: List[str] + ALLOWED_TOKENS: List[str] + BASE_URL: str = f"https://generativelanguage.googleapis.com/{API_VERSION}" + AUTH_TOKEN: str = "" + MAX_FAILURES: int = 3 + TEST_MODEL: str = DEFAULT_MODEL + TIME_OUT: int = DEFAULT_TIMEOUT + MAX_RETRIES: int = MAX_RETRIES + PROXIES: List[str] = [] + PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: bool = True # 是否使用一致性哈希来选择代理 + VERTEX_API_KEYS: List[str] = [] + VERTEX_EXPRESS_BASE_URL: str = "https://aiplatform.googleapis.com/v1beta1/publishers/google" + + # 智能路由配置 + URL_NORMALIZATION_ENABLED: bool = False # 是否启用智能路由映射功能 + + # 模型相关配置 + SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"] + IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"] + FILTERED_MODELS: List[str] = DEFAULT_FILTER_MODELS + TOOLS_CODE_EXECUTION_ENABLED: bool = False + 
SHOW_SEARCH_LINK: bool = True + SHOW_THINKING_PROCESS: bool = True + THINKING_MODELS: List[str] = [] + THINKING_BUDGET_MAP: Dict[str, float] = {} + + # TTS相关配置 + TTS_MODEL: str = "gemini-2.5-flash-preview-tts" + TTS_VOICE_NAME: str = "Zephyr" + TTS_SPEED: str = "normal" + + # 图像生成相关配置 + PAID_KEY: str = "" + CREATE_IMAGE_MODEL: str = DEFAULT_CREATE_IMAGE_MODEL + UPLOAD_PROVIDER: str = "smms" + SMMS_SECRET_TOKEN: str = "" + PICGO_API_KEY: str = "" + CLOUDFLARE_IMGBED_URL: str = "" + CLOUDFLARE_IMGBED_AUTH_CODE: str = "" + + # 流式输出优化器配置 + STREAM_OPTIMIZER_ENABLED: bool = False + STREAM_MIN_DELAY: float = DEFAULT_STREAM_MIN_DELAY + STREAM_MAX_DELAY: float = DEFAULT_STREAM_MAX_DELAY + STREAM_SHORT_TEXT_THRESHOLD: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD + STREAM_LONG_TEXT_THRESHOLD: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD + STREAM_CHUNK_SIZE: int = DEFAULT_STREAM_CHUNK_SIZE + + # 假流式配置 (Fake Streaming Configuration) + FAKE_STREAM_ENABLED: bool = False # 是否启用假流式输出 + FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS: int = 5 # 假流式发送空数据的间隔时间(秒) + + # 调度器配置 + CHECK_INTERVAL_HOURS: int = 1 # 默认检查间隔为1小时 + TIMEZONE: str = "Asia/Shanghai" # 默认时区 + + # github + GITHUB_REPO_OWNER: str = "snailyp" + GITHUB_REPO_NAME: str = "gemini-balance" + + # 日志配置 + LOG_LEVEL: str = "INFO" + AUTO_DELETE_ERROR_LOGS_ENABLED: bool = True + AUTO_DELETE_ERROR_LOGS_DAYS: int = 7 + AUTO_DELETE_REQUEST_LOGS_ENABLED: bool = False + AUTO_DELETE_REQUEST_LOGS_DAYS: int = 30 + SAFETY_SETTINGS: List[Dict[str, str]] = DEFAULT_SAFETY_SETTINGS + + + def __init__(self, **kwargs): + super().__init__(**kwargs) + # 设置默认AUTH_TOKEN(如果未提供) + if not self.AUTH_TOKEN and self.ALLOWED_TOKENS: + self.AUTH_TOKEN = self.ALLOWED_TOKENS[0] + + +# 创建全局配置实例 +settings = Settings() + + +def _parse_db_value(key: str, db_value: str, target_type: Type) -> Any: + """尝试将数据库字符串值解析为目标 Python 类型""" + from app.log.logger import get_config_logger + + logger = get_config_logger() + try: + # 处理 List[str] + if target_type == List[str]: + try: + parsed 
= json.loads(db_value) + if isinstance(parsed, list): + return [str(item) for item in parsed] + except json.JSONDecodeError: + return [item.strip() for item in db_value.split(",") if item.strip()] + logger.warning( + f"Could not parse '{db_value}' as List[str] for key '{key}', falling back to comma split or empty list." + ) + return [item.strip() for item in db_value.split(",") if item.strip()] + # 处理 Dict[str, float] + elif target_type == Dict[str, float]: + parsed_dict = {} + try: + parsed = json.loads(db_value) + if isinstance(parsed, dict): + parsed_dict = {str(k): float(v) for k, v in parsed.items()} + else: + logger.warning( + f"Parsed DB value for key '{key}' is not a dictionary type. Value: {db_value}" + ) + except (json.JSONDecodeError, ValueError, TypeError) as e1: + if isinstance(e1, json.JSONDecodeError) and "'" in db_value: + logger.warning( + f"Failed initial JSON parse for key '{key}'. Attempting to replace single quotes. Error: {e1}" + ) + try: + corrected_db_value = db_value.replace("'", '"') + parsed = json.loads(corrected_db_value) + if isinstance(parsed, dict): + parsed_dict = {str(k): float(v) for k, v in parsed.items()} + else: + logger.warning( + f"Parsed DB value (after quote replacement) for key '{key}' is not a dictionary type. Value: {corrected_db_value}" + ) + except (json.JSONDecodeError, ValueError, TypeError) as e2: + logger.error( + f"Could not parse '{db_value}' as Dict[str, float] for key '{key}' even after replacing quotes: {e2}. Returning empty dict." + ) + else: + logger.error( + f"Could not parse '{db_value}' as Dict[str, float] for key '{key}': {e1}. Returning empty dict." 
+ ) + return parsed_dict + # 处理 List[Dict[str, str]] + elif target_type == List[Dict[str, str]]: + try: + parsed = json.loads(db_value) + if isinstance(parsed, list): + # 验证列表中的每个元素是否为字典,并且键和值都是字符串 + valid = all( + isinstance(item, dict) + and all(isinstance(k, str) for k in item.keys()) + and all(isinstance(v, str) for v in item.values()) + for item in parsed + ) + if valid: + return parsed + else: + logger.warning( + f"Invalid structure in List[Dict[str, str]] for key '{key}'. Value: {db_value}" + ) + return [] + else: + logger.warning( + f"Parsed DB value for key '{key}' is not a list type. Value: {db_value}" + ) + return [] + except json.JSONDecodeError: + logger.error( + f"Could not parse '{db_value}' as JSON for List[Dict[str, str]] for key '{key}'. Returning empty list." + ) + return [] + except Exception as e: + logger.error( + f"Error parsing List[Dict[str, str]] for key '{key}': {e}. Value: {db_value}. Returning empty list." + ) + return [] + # 处理 bool + elif target_type == bool: + return db_value.lower() in ("true", "1", "yes", "on") + # 处理 int + elif target_type == int: + return int(db_value) + # 处理 float + elif target_type == float: + return float(db_value) + # 默认为 str 或其他 pydantic 能直接处理的类型 + else: + return db_value + except (ValueError, TypeError, json.JSONDecodeError) as e: + logger.warning( + f"Failed to parse db_value '{db_value}' for key '{key}' as type {target_type}: {e}. Using original string value." + ) + return db_value # 解析失败则返回原始字符串 + + +async def sync_initial_settings(): + """ + 应用启动时同步配置: + 1. 从数据库加载设置。 + 2. 将数据库设置合并到内存 settings (数据库优先)。 + 3. 
将最终的内存 settings 同步回数据库。 + """ + from app.log.logger import get_config_logger + + logger = get_config_logger() + # 延迟导入以避免循环依赖和确保数据库连接已初始化 + from app.database.connection import database + from app.database.models import Settings as SettingsModel + + global settings + logger.info("Starting initial settings synchronization...") + + if not database.is_connected: + try: + await database.connect() + logger.info("Database connection established for initial sync.") + except Exception as e: + logger.error( + f"Failed to connect to database for initial settings sync: {e}. Skipping sync." + ) + return + + try: + # 1. 从数据库加载设置 + db_settings_raw: List[Dict[str, Any]] = [] + try: + query = select(SettingsModel.key, SettingsModel.value) + results = await database.fetch_all(query) + db_settings_raw = [ + {"key": row["key"], "value": row["value"]} for row in results + ] + logger.info(f"Fetched {len(db_settings_raw)} settings from database.") + except Exception as e: + logger.error( + f"Failed to fetch settings from database: {e}. Proceeding with environment/dotenv settings." + ) + # 即使数据库读取失败,也要继续执行,确保基于 env/dotenv 的配置能同步到数据库 + + db_settings_map: Dict[str, str] = { + s["key"]: s["value"] for s in db_settings_raw + } + + # 2. 将数据库设置合并到内存 settings (数据库优先) + updated_in_memory = False + + for key, db_value in db_settings_map.items(): + if key == "DATABASE_TYPE": + logger.debug( + f"Skipping update of '{key}' in memory from database. " + "This setting is controlled by environment/dotenv." 
+ ) + continue + if hasattr(settings, key): + target_type = Settings.__annotations__.get(key) + if target_type: + try: + parsed_db_value = _parse_db_value(key, db_value, target_type) + memory_value = getattr(settings, key) + + # 比较解析后的值和内存中的值 + # 注意:对于列表等复杂类型,直接比较可能不够健壮,但这里简化处理 + if parsed_db_value != memory_value: + # 检查类型是否匹配,以防解析函数返回了不兼容的类型 + type_match = False + if target_type == List[str] and isinstance( + parsed_db_value, list + ): + type_match = True + elif target_type == Dict[str, float] and isinstance( + parsed_db_value, dict + ): + type_match = True + elif target_type not in ( + List[str], + Dict[str, float], + ) and isinstance(parsed_db_value, target_type): + type_match = True + + if type_match: + setattr(settings, key, parsed_db_value) + logger.debug( + f"Updated setting '{key}' in memory from database value ({target_type})." + ) + updated_in_memory = True + else: + logger.warning( + f"Parsed DB value type mismatch for key '{key}'. Expected {target_type}, got {type(parsed_db_value)}. Skipping update." + ) + + except Exception as e: + logger.error( + f"Error processing database setting for key '{key}': {e}" + ) + else: + logger.warning( + f"Database setting '{key}' not found in Settings model definition. Ignoring." + ) + + # 如果内存中有更新,重新验证 Pydantic 模型(可选但推荐) + if updated_in_memory: + try: + # 重新加载以确保类型转换和验证 + settings = Settings(**settings.model_dump()) + logger.info( + "Settings object re-validated after merging database values." + ) + except ValidationError as e: + logger.error( + f"Validation error after merging database settings: {e}. Settings might be inconsistent." + ) + + # 3. 
将最终的内存 settings 同步回数据库 + final_memory_settings = settings.model_dump() + settings_to_update: List[Dict[str, Any]] = [] + settings_to_insert: List[Dict[str, Any]] = [] + now = datetime.datetime.now(datetime.timezone.utc) + + existing_db_keys = set(db_settings_map.keys()) + + for key, value in final_memory_settings.items(): + if key == "DATABASE_TYPE": + logger.debug( + f"Skipping synchronization of '{key}' to database. " + "This setting is controlled by environment/dotenv." + ) + continue + + # 序列化值为字符串或 JSON 字符串 + if isinstance(value, (list, dict)): + db_value = json.dumps( + value, ensure_ascii=False + ) + elif isinstance(value, bool): + db_value = str(value).lower() + elif value is None: + db_value = "" + else: + db_value = str(value) + + data = { + "key": key, + "value": db_value, + "description": f"{key} configuration setting", + "updated_at": now, + } + + if key in existing_db_keys: + # 仅当值与数据库中的不同时才更新 + if db_settings_map[key] != db_value: + settings_to_update.append(data) + else: + # 如果键不在数据库中,则插入 + data["created_at"] = now + settings_to_insert.append(data) + + # 在事务中执行批量插入和更新 + if settings_to_insert or settings_to_update: + try: + async with database.transaction(): + if settings_to_insert: + # 获取现有描述以避免覆盖 + query_existing = select( + SettingsModel.key, SettingsModel.description + ).where( + SettingsModel.key.in_( + [s["key"] for s in settings_to_insert] + ) + ) + existing_desc = { + row["key"]: row["description"] + for row in await database.fetch_all(query_existing) + } + for item in settings_to_insert: + item["description"] = existing_desc.get( + item["key"], item["description"] + ) + + query_insert = insert(SettingsModel).values(settings_to_insert) + await database.execute(query=query_insert) + logger.info( + f"Synced (inserted) {len(settings_to_insert)} settings to database." 
+ ) + + if settings_to_update: + # 获取现有描述以避免覆盖 + query_existing = select( + SettingsModel.key, SettingsModel.description + ).where( + SettingsModel.key.in_( + [s["key"] for s in settings_to_update] + ) + ) + existing_desc = { + row["key"]: row["description"] + for row in await database.fetch_all(query_existing) + } + + for setting_data in settings_to_update: + setting_data["description"] = existing_desc.get( + setting_data["key"], setting_data["description"] + ) + query_update = ( + update(SettingsModel) + .where(SettingsModel.key == setting_data["key"]) + .values( + value=setting_data["value"], + description=setting_data["description"], + updated_at=setting_data["updated_at"], + ) + ) + await database.execute(query=query_update) + logger.info( + f"Synced (updated) {len(settings_to_update)} settings to database." + ) + except Exception as e: + logger.error( + f"Failed to sync settings to database during startup: {str(e)}" + ) + else: + logger.info( + "No setting changes detected between memory and database during initial sync." 
+ ) + + # 刷新日志等级 + Logger.update_log_levels(final_memory_settings.get("LOG_LEVEL")) + + except Exception as e: + logger.error(f"An unexpected error occurred during initial settings sync: {e}") + finally: + if database.is_connected: + try: + pass + except Exception as e: + logger.error(f"Error disconnecting database after initial sync: {e}") + + logger.info("Initial settings synchronization finished.") diff --git a/app/core/application.py b/app/core/application.py new file mode 100644 index 0000000000000000000000000000000000000000..16b074fc0d794254e0272624c9268c0b4b24428d --- /dev/null +++ b/app/core/application.py @@ -0,0 +1,153 @@ +from contextlib import asynccontextmanager +from pathlib import Path + +from fastapi import FastAPI +from fastapi.staticfiles import StaticFiles +from fastapi.templating import Jinja2Templates + +from app.config.config import settings, sync_initial_settings +from app.database.connection import connect_to_db, disconnect_from_db +from app.database.initialization import initialize_database +from app.exception.exceptions import setup_exception_handlers +from app.log.logger import get_application_logger +from app.middleware.middleware import setup_middlewares +from app.router.routes import setup_routers +from app.scheduler.scheduled_tasks import start_scheduler, stop_scheduler +from app.service.key.key_manager import get_key_manager_instance +from app.service.update.update_service import check_for_updates +from app.utils.helpers import get_current_version + +logger = get_application_logger() + +PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent +STATIC_DIR = PROJECT_ROOT / "app" / "static" +TEMPLATES_DIR = PROJECT_ROOT / "app" / "templates" + +# 初始化模板引擎,并添加全局变量 +templates = Jinja2Templates(directory="app/templates") + + +# 定义一个函数来更新模板全局变量 +def update_template_globals(app: FastAPI, update_info: dict): + # Jinja2Templates 实例没有直接更新全局变量的方法 + # 我们需要在请求上下文中传递这些变量,或者修改 Jinja 环境 + # 更简单的方法是将其存储在 app.state 中,并在渲染时传递 + app.state.update_info 
= update_info + logger.info(f"Update info stored in app.state: {update_info}") + + +# --- Helper functions for lifespan --- +async def _setup_database_and_config(app_settings): + """Initializes database, syncs settings, and initializes KeyManager.""" + initialize_database() + logger.info("Database initialized successfully") + await connect_to_db() + await sync_initial_settings() + await get_key_manager_instance(app_settings.API_KEYS, app_settings.VERTEX_API_KEYS) + logger.info("Database, config sync, and KeyManager initialized successfully") + + +async def _shutdown_database(): + """Disconnects from the database.""" + await disconnect_from_db() + + +def _start_scheduler(): + """Starts the background scheduler.""" + try: + start_scheduler() + logger.info("Scheduler started successfully.") + except Exception as e: + logger.error(f"Failed to start scheduler: {e}") + + +def _stop_scheduler(): + """Stops the background scheduler.""" + stop_scheduler() + + +async def _perform_update_check(app: FastAPI): + """Checks for updates and stores the info in app.state.""" + update_available, latest_version, error_message = await check_for_updates() + current_version = get_current_version() + update_info = { + "update_available": update_available, + "latest_version": latest_version, + "error_message": error_message, + "current_version": current_version, + } + if not hasattr(app, "state"): + from starlette.datastructures import State + + app.state = State() + app.state.update_info = update_info + logger.info(f"Update check completed. Info: {update_info}") + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Manages the application startup and shutdown events. 
+ + Args: + app: FastAPI应用实例 + """ + logger.info("Application starting up...") + try: + await _setup_database_and_config(settings) + await _perform_update_check(app) + _start_scheduler() + + except Exception as e: + logger.critical( + f"Critical error during application startup: {str(e)}", exc_info=True + ) + + yield + + logger.info("Application shutting down...") + _stop_scheduler() + await _shutdown_database() + + +def create_app() -> FastAPI: + """ + 创建并配置FastAPI应用程序实例 + + Returns: + FastAPI: 配置好的FastAPI应用程序实例 + """ + + # 创建FastAPI应用 + current_version = get_current_version() + app = FastAPI( + title="Gemini Balance API", + description="Gemini API代理服务,支持负载均衡和密钥管理", + version=current_version, + lifespan=lifespan, + ) + + if not hasattr(app, "state"): + from starlette.datastructures import State + + app.state = State() + app.state.update_info = { + "update_available": False, + "latest_version": None, + "error_message": "Initializing...", + "current_version": current_version, + } + + # 配置静态文件 + app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static") + + # 配置中间件 + setup_middlewares(app) + + # 配置异常处理器 + setup_exception_handlers(app) + + # 配置路由 + setup_routers(app) + + return app diff --git a/app/core/constants.py b/app/core/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..21d72aaf9271caf3b43669b870e15e23a95a90a8 --- /dev/null +++ b/app/core/constants.py @@ -0,0 +1,79 @@ +""" +常量定义模块 +""" + +# API相关常量 +API_VERSION = "v1beta" +DEFAULT_TIMEOUT = 300 # 秒 +MAX_RETRIES = 3 # 最大重试次数 + +# 模型相关常量 +SUPPORTED_ROLES = ["user", "model", "system"] +DEFAULT_MODEL = "gemini-1.5-flash" +DEFAULT_TEMPERATURE = 0.7 +DEFAULT_MAX_TOKENS = 8192 +DEFAULT_TOP_P = 0.9 +DEFAULT_TOP_K = 40 +DEFAULT_FILTER_MODELS = [ + "gemini-1.0-pro-vision-latest", + "gemini-pro-vision", + "chat-bison-001", + "text-bison-001", + "embedding-gecko-001" + ] +DEFAULT_CREATE_IMAGE_MODEL = "imagen-3.0-generate-002" + +# 图像生成相关常量 +VALID_IMAGE_RATIOS = ["1:1", "3:4", 
"4:3", "9:16", "16:9"] + +# 上传提供商 +UPLOAD_PROVIDERS = ["smms", "picgo", "cloudflare_imgbed"] +DEFAULT_UPLOAD_PROVIDER = "smms" + +# 流式输出相关常量 +DEFAULT_STREAM_MIN_DELAY = 0.016 +DEFAULT_STREAM_MAX_DELAY = 0.024 +DEFAULT_STREAM_SHORT_TEXT_THRESHOLD = 10 +DEFAULT_STREAM_LONG_TEXT_THRESHOLD = 50 +DEFAULT_STREAM_CHUNK_SIZE = 5 + +# 正则表达式模式 +IMAGE_URL_PATTERN = r'!\[(.*?)\]\((.*?)\)' +DATA_URL_PATTERN = r'data:([^;]+);base64,(.+)' + +# Audio/Video Settings +SUPPORTED_AUDIO_FORMATS = ["wav", "mp3", "flac", "ogg"] +SUPPORTED_VIDEO_FORMATS = ["mp4", "mov", "avi", "webm"] +MAX_AUDIO_SIZE_BYTES = 50 * 1024 * 1024 # Example: 50MB limit for Base64 payload +MAX_VIDEO_SIZE_BYTES = 200 * 1024 * 1024 # Example: 200MB limit + +# Optional: Define MIME type mappings if needed, or handle directly in converter +AUDIO_FORMAT_TO_MIMETYPE = { + "wav": "audio/wav", + "mp3": "audio/mpeg", + "flac": "audio/flac", + "ogg": "audio/ogg", +} + +VIDEO_FORMAT_TO_MIMETYPE = { + "mp4": "video/mp4", + "mov": "video/quicktime", + "avi": "video/x-msvideo", + "webm": "video/webm", +} + +GEMINI_2_FLASH_EXP_SAFETY_SETTINGS = [ + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"}, + ] + +DEFAULT_SAFETY_SETTINGS = [ + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, + ] \ No newline at end of file diff --git a/app/core/security.py b/app/core/security.py new file mode 100644 index 
0000000000000000000000000000000000000000..eebad69146b5893c111c4d0ff3d8b1bf3c2e29c9 --- /dev/null +++ b/app/core/security.py @@ -0,0 +1,90 @@ +from typing import Optional + +from fastapi import Header, HTTPException + +from app.config.config import settings +from app.log.logger import get_security_logger + +logger = get_security_logger() + + +def verify_auth_token(token: str) -> bool: + return token == settings.AUTH_TOKEN + + +class SecurityService: + + async def verify_key(self, key: str): + if key not in settings.ALLOWED_TOKENS and key != settings.AUTH_TOKEN: + logger.error("Invalid key") + raise HTTPException(status_code=401, detail="Invalid key") + return key + + async def verify_authorization( + self, authorization: Optional[str] = Header(None) + ) -> str: + if not authorization: + logger.error("Missing Authorization header") + raise HTTPException(status_code=401, detail="Missing Authorization header") + + if not authorization.startswith("Bearer "): + logger.error("Invalid Authorization header format") + raise HTTPException( + status_code=401, detail="Invalid Authorization header format" + ) + + token = authorization.replace("Bearer ", "") + if token not in settings.ALLOWED_TOKENS and token != settings.AUTH_TOKEN: + logger.error("Invalid token") + raise HTTPException(status_code=401, detail="Invalid token") + + return token + + async def verify_goog_api_key( + self, x_goog_api_key: Optional[str] = Header(None) + ) -> str: + """验证Google API Key""" + if not x_goog_api_key: + logger.error("Missing x-goog-api-key header") + raise HTTPException(status_code=401, detail="Missing x-goog-api-key header") + + if ( + x_goog_api_key not in settings.ALLOWED_TOKENS + and x_goog_api_key != settings.AUTH_TOKEN + ): + logger.error("Invalid x-goog-api-key") + raise HTTPException(status_code=401, detail="Invalid x-goog-api-key") + + return x_goog_api_key + + async def verify_auth_token( + self, authorization: Optional[str] = Header(None) + ) -> str: + if not authorization: + 
logger.error("Missing auth_token header") + raise HTTPException(status_code=401, detail="Missing auth_token header") + token = authorization.replace("Bearer ", "") + if token != settings.AUTH_TOKEN: + logger.error("Invalid auth_token") + raise HTTPException(status_code=401, detail="Invalid auth_token") + + return token + + async def verify_key_or_goog_api_key( + self, key: Optional[str] = None , x_goog_api_key: Optional[str] = Header(None) + ) -> str: + """验证URL中的key或请求头中的x-goog-api-key""" + # 如果URL中的key有效,直接返回 + if key in settings.ALLOWED_TOKENS or key == settings.AUTH_TOKEN: + return key + + # 否则检查请求头中的x-goog-api-key + if not x_goog_api_key: + logger.error("Invalid key and missing x-goog-api-key header") + raise HTTPException(status_code=401, detail="Invalid key and missing x-goog-api-key header") + + if x_goog_api_key not in settings.ALLOWED_TOKENS and x_goog_api_key != settings.AUTH_TOKEN: + logger.error("Invalid key and invalid x-goog-api-key") + raise HTTPException(status_code=401, detail="Invalid key and invalid x-goog-api-key") + + return x_goog_api_key \ No newline at end of file diff --git a/app/database/__init__.py b/app/database/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0c166ad8fbf63035ed29702fbc0a58a98fca67e0 --- /dev/null +++ b/app/database/__init__.py @@ -0,0 +1,3 @@ +""" +数据库模块 +""" diff --git a/app/database/connection.py b/app/database/connection.py new file mode 100644 index 0000000000000000000000000000000000000000..2a088d5d030290d12d38ddcbb5097e7c5beb8ab7 --- /dev/null +++ b/app/database/connection.py @@ -0,0 +1,71 @@ +""" +数据库连接池模块 +""" +from pathlib import Path +from urllib.parse import quote_plus +from databases import Database +from sqlalchemy import create_engine, MetaData +from sqlalchemy.ext.declarative import declarative_base + +from app.config.config import settings +from app.log.logger import get_database_logger + +logger = get_database_logger() + +# 数据库URL +if settings.DATABASE_TYPE == "sqlite": + 
# 确保 data 目录存在 + data_dir = Path("data") + data_dir.mkdir(exist_ok=True) + db_path = data_dir / settings.SQLITE_DATABASE + DATABASE_URL = f"sqlite:///{db_path}" +elif settings.DATABASE_TYPE == "mysql": + if settings.MYSQL_SOCKET: + DATABASE_URL = f"mysql+pymysql://{settings.MYSQL_USER}:{quote_plus(settings.MYSQL_PASSWORD)}@/{settings.MYSQL_DATABASE}?unix_socket={settings.MYSQL_SOCKET}" + else: + DATABASE_URL = f"mysql+pymysql://{settings.MYSQL_USER}:{quote_plus(settings.MYSQL_PASSWORD)}@{settings.MYSQL_HOST}:{settings.MYSQL_PORT}/{settings.MYSQL_DATABASE}" +else: + raise ValueError("Unsupported database type. Please set DATABASE_TYPE to 'sqlite' or 'mysql'.") + +# 创建数据库引擎 +# pool_pre_ping=True: 在从连接池获取连接前执行简单的 "ping" 测试,确保连接有效 +engine = create_engine(DATABASE_URL, pool_pre_ping=True) + +# 创建元数据对象 +metadata = MetaData() + +# 创建基类 +Base = declarative_base(metadata=metadata) + +# 创建数据库连接池,并配置连接池参数,在sqlite中不使用连接池 +# min_size/max_size: 连接池的最小/最大连接数 +# pool_recycle=3600: 连接在池中允许存在的最大秒数(生命周期)。 +# 设置为 3600 秒(1小时),确保在 MySQL 默认的 wait_timeout (通常8小时) 或其他网络超时之前回收连接。 +# 如果遇到连接失效问题,可以尝试调低此值,使其小于实际的 wait_timeout 或网络超时时间。 +# databases 库会自动处理连接失效后的重连尝试。 +if settings.DATABASE_TYPE == "sqlite": + database = Database(DATABASE_URL) +else: + database = Database(DATABASE_URL, min_size=5, max_size=20, pool_recycle=1800) + +async def connect_to_db(): + """ + 连接到数据库 + """ + try: + await database.connect() + logger.info(f"Connected to {settings.DATABASE_TYPE}") + except Exception as e: + logger.error(f"Failed to connect to database: {str(e)}") + raise + + +async def disconnect_from_db(): + """ + 断开数据库连接 + """ + try: + await database.disconnect() + logger.info(f"Disconnected from {settings.DATABASE_TYPE}") + except Exception as e: + logger.error(f"Failed to disconnect from database: {str(e)}") diff --git a/app/database/initialization.py b/app/database/initialization.py new file mode 100644 index 0000000000000000000000000000000000000000..a9e88632d8daffd52c415d922973fe4578779323 --- /dev/null 
+++ b/app/database/initialization.py @@ -0,0 +1,77 @@ +""" +数据库初始化模块 +""" +from dotenv import dotenv_values + +from sqlalchemy import inspect +from sqlalchemy.orm import Session + +from app.database.connection import engine, Base +from app.database.models import Settings +from app.log.logger import get_database_logger + +logger = get_database_logger() + + +def create_tables(): + """ + 创建数据库表 + """ + try: + # 创建所有表 + Base.metadata.create_all(engine) + logger.info("Database tables created successfully") + except Exception as e: + logger.error(f"Failed to create database tables: {str(e)}") + raise + + +def import_env_to_settings(): + """ + 将.env文件中的配置项导入到t_settings表中 + """ + try: + # 获取.env文件中的所有配置项 + env_values = dotenv_values(".env") + + # 获取检查器 + inspector = inspect(engine) + + # 检查t_settings表是否存在 + if "t_settings" in inspector.get_table_names(): + # 使用Session进行数据库操作 + with Session(engine) as session: + # 获取所有现有的配置项 + current_settings = {setting.key: setting for setting in session.query(Settings).all()} + + # 遍历所有配置项 + for key, value in env_values.items(): + # 检查配置项是否已存在 + if key not in current_settings: + # 插入配置项 + new_setting = Settings(key=key, value=value) + session.add(new_setting) + logger.info(f"Inserted setting: {key}") + + # 提交事务 + session.commit() + + logger.info("Environment variables imported to settings table successfully") + except Exception as e: + logger.error(f"Failed to import environment variables to settings table: {str(e)}") + raise + + +def initialize_database(): + """ + 初始化数据库 + """ + try: + # 创建表 + create_tables() + + # 导入环境变量 + import_env_to_settings() + except Exception as e: + logger.error(f"Failed to initialize database: {str(e)}") + raise diff --git a/app/database/models.py b/app/database/models.py new file mode 100644 index 0000000000000000000000000000000000000000..c33ae6a7fb21fd19e9445ef47b1afe4668dec7bd --- /dev/null +++ b/app/database/models.py @@ -0,0 +1,62 @@ +""" +数据库模型模块 +""" +import datetime +from sqlalchemy import Column, 
Integer, String, Text, DateTime, JSON, Boolean + +from app.database.connection import Base + + +class Settings(Base): + """ + 设置表,对应.env中的配置项 + """ + __tablename__ = "t_settings" + + id = Column(Integer, primary_key=True, autoincrement=True) + key = Column(String(100), nullable=False, unique=True, comment="配置项键名") + value = Column(Text, nullable=True, comment="配置项值") + description = Column(String(255), nullable=True, comment="配置项描述") + created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") + updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") + + def __repr__(self): + return f"" + + +class ErrorLog(Base): + """ + 错误日志表 + """ + __tablename__ = "t_error_logs" + + id = Column(Integer, primary_key=True, autoincrement=True) + gemini_key = Column(String(100), nullable=True, comment="Gemini API密钥") + model_name = Column(String(100), nullable=True, comment="模型名称") + error_type = Column(String(50), nullable=True, comment="错误类型") + error_log = Column(Text, nullable=True, comment="错误日志") + error_code = Column(Integer, nullable=True, comment="错误代码") + request_msg = Column(JSON, nullable=True, comment="请求消息") + request_time = Column(DateTime, default=datetime.datetime.now, comment="请求时间") + + def __repr__(self): + return f"" + + +class RequestLog(Base): + """ + API 请求日志表 + """ + + __tablename__ = "t_request_log" + + id = Column(Integer, primary_key=True, autoincrement=True) + request_time = Column(DateTime, default=datetime.datetime.now, comment="请求时间") + model_name = Column(String(100), nullable=True, comment="模型名称") + api_key = Column(String(100), nullable=True, comment="使用的API密钥") + is_success = Column(Boolean, nullable=False, comment="请求是否成功") + status_code = Column(Integer, nullable=True, comment="API响应状态码") + latency_ms = Column(Integer, nullable=True, comment="请求耗时(毫秒)") + + def __repr__(self): + return f"" diff --git a/app/database/services.py b/app/database/services.py new file mode 100644 
index 0000000000000000000000000000000000000000..893b60881cb9e8abdbdf95c2da853103daa9d22a --- /dev/null +++ b/app/database/services.py @@ -0,0 +1,429 @@ +""" +数据库服务模块 +""" +from typing import List, Optional, Dict, Any, Union +from datetime import datetime +from sqlalchemy import func, desc, asc, select, insert, update, delete +import json +from app.database.connection import database +from app.database.models import Settings, ErrorLog, RequestLog +from app.log.logger import get_database_logger + +logger = get_database_logger() + + +async def get_all_settings() -> List[Dict[str, Any]]: + """ + 获取所有设置 + + Returns: + List[Dict[str, Any]]: 设置列表 + """ + try: + query = select(Settings) + result = await database.fetch_all(query) + return [dict(row) for row in result] + except Exception as e: + logger.error(f"Failed to get all settings: {str(e)}") + raise + + +async def get_setting(key: str) -> Optional[Dict[str, Any]]: + """ + 获取指定键的设置 + + Args: + key: 设置键名 + + Returns: + Optional[Dict[str, Any]]: 设置信息,如果不存在则返回None + """ + try: + query = select(Settings).where(Settings.key == key) + result = await database.fetch_one(query) + return dict(result) if result else None + except Exception as e: + logger.error(f"Failed to get setting {key}: {str(e)}") + raise + + +async def update_setting(key: str, value: str, description: Optional[str] = None) -> bool: + """ + 更新设置 + + Args: + key: 设置键名 + value: 设置值 + description: 设置描述 + + Returns: + bool: 是否更新成功 + """ + try: + # 检查设置是否存在 + setting = await get_setting(key) + + if setting: + # 更新设置 + query = ( + update(Settings) + .where(Settings.key == key) + .values( + value=value, + description=description if description else setting["description"], + updated_at=datetime.now() + ) + ) + await database.execute(query) + logger.info(f"Updated setting: {key}") + return True + else: + # 插入设置 + query = ( + insert(Settings) + .values( + key=key, + value=value, + description=description, + created_at=datetime.now(), + updated_at=datetime.now() + ) + ) 
+ await database.execute(query) + logger.info(f"Inserted setting: {key}") + return True + except Exception as e: + logger.error(f"Failed to update setting {key}: {str(e)}") + return False + + +async def add_error_log( + gemini_key: Optional[str] = None, + model_name: Optional[str] = None, + error_type: Optional[str] = None, + error_log: Optional[str] = None, + error_code: Optional[int] = None, + request_msg: Optional[Union[Dict[str, Any], str]] = None +) -> bool: + """ + 添加错误日志 + + Args: + gemini_key: Gemini API密钥 + error_log: 错误日志 + error_code: 错误代码 (例如 HTTP 状态码) + request_msg: 请求消息 + + Returns: + bool: 是否添加成功 + """ + try: + # 如果request_msg是字典,则转换为JSON字符串 + if isinstance(request_msg, dict): + request_msg_json = request_msg + elif isinstance(request_msg, str): + try: + request_msg_json = json.loads(request_msg) + except json.JSONDecodeError: + request_msg_json = {"message": request_msg} + else: + request_msg_json = None + + # 插入错误日志 + query = ( + insert(ErrorLog) + .values( + gemini_key=gemini_key, + error_type=error_type, + error_log=error_log, + model_name=model_name, + error_code=error_code, + request_msg=request_msg_json, + request_time=datetime.now() + ) + ) + await database.execute(query) + logger.info(f"Added error log for key: {gemini_key}") + return True + except Exception as e: + logger.error(f"Failed to add error log: {str(e)}") + return False + + +async def get_error_logs( + limit: int = 20, + offset: int = 0, + key_search: Optional[str] = None, + error_search: Optional[str] = None, + error_code_search: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + sort_by: str = 'id', + sort_order: str = 'desc' +) -> List[Dict[str, Any]]: + """ + 获取错误日志,支持搜索、日期过滤和排序 + + Args: + limit (int): 限制数量 + offset (int): 偏移量 + key_search (Optional[str]): Gemini密钥搜索词 (模糊匹配) + error_search (Optional[str]): 错误类型或日志内容搜索词 (模糊匹配) + error_code_search (Optional[str]): 错误码搜索词 (精确匹配) + start_date (Optional[datetime]): 开始日期时间 + 
end_date (Optional[datetime]): 结束日期时间 + sort_by (str): 排序字段 (例如 'id', 'request_time') + sort_order (str): 排序顺序 ('asc' or 'desc') + + Returns: + List[Dict[str, Any]]: 错误日志列表 + """ + try: + query = select( + ErrorLog.id, + ErrorLog.gemini_key, + ErrorLog.model_name, + ErrorLog.error_type, + ErrorLog.error_log, + ErrorLog.error_code, + ErrorLog.request_time + ) + + if key_search: + query = query.where(ErrorLog.gemini_key.ilike(f"%{key_search}%")) + if error_search: + query = query.where( + (ErrorLog.error_type.ilike(f"%{error_search}%")) | + (ErrorLog.error_log.ilike(f"%{error_search}%")) + ) + if start_date: + query = query.where(ErrorLog.request_time >= start_date) + if end_date: + query = query.where(ErrorLog.request_time < end_date) + if error_code_search: + try: + error_code_int = int(error_code_search) + query = query.where(ErrorLog.error_code == error_code_int) + except ValueError: + logger.warning(f"Invalid format for error_code_search: '{error_code_search}'. Expected an integer. Skipping error code filter.") + + sort_column = getattr(ErrorLog, sort_by, ErrorLog.id) + if sort_order.lower() == 'asc': + query = query.order_by(asc(sort_column)) + else: + query = query.order_by(desc(sort_column)) + + query = query.limit(limit).offset(offset) + + result = await database.fetch_all(query) + return [dict(row) for row in result] + except Exception as e: + logger.exception(f"Failed to get error logs with filters: {str(e)}") + raise + + +async def get_error_logs_count( + key_search: Optional[str] = None, + error_search: Optional[str] = None, + error_code_search: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None +) -> int: + """ + 获取符合条件的错误日志总数 + + Args: + key_search (Optional[str]): Gemini密钥搜索词 (模糊匹配) + error_search (Optional[str]): 错误类型或日志内容搜索词 (模糊匹配) + error_code_search (Optional[str]): 错误码搜索词 (精确匹配) + start_date (Optional[datetime]): 开始日期时间 + end_date (Optional[datetime]): 结束日期时间 + + Returns: + int: 日志总数 + """ + try: 
+ query = select(func.count()).select_from(ErrorLog) + + if key_search: + query = query.where(ErrorLog.gemini_key.ilike(f"%{key_search}%")) + if error_search: + query = query.where( + (ErrorLog.error_type.ilike(f"%{error_search}%")) | + (ErrorLog.error_log.ilike(f"%{error_search}%")) + ) + if start_date: + query = query.where(ErrorLog.request_time >= start_date) + if end_date: + query = query.where(ErrorLog.request_time < end_date) + if error_code_search: + try: + error_code_int = int(error_code_search) + query = query.where(ErrorLog.error_code == error_code_int) + except ValueError: + logger.warning(f"Invalid format for error_code_search in count: '{error_code_search}'. Expected an integer. Skipping error code filter.") + + + count_result = await database.fetch_one(query) + return count_result[0] if count_result else 0 + except Exception as e: + logger.exception(f"Failed to count error logs with filters: {str(e)}") + raise + + +# 新增函数:获取单条错误日志详情 +async def get_error_log_details(log_id: int) -> Optional[Dict[str, Any]]: + """ + 根据 ID 获取单个错误日志的详细信息 + + Args: + log_id (int): 错误日志的 ID + + Returns: + Optional[Dict[str, Any]]: 包含日志详细信息的字典,如果未找到则返回 None + """ + try: + query = select(ErrorLog).where(ErrorLog.id == log_id) + result = await database.fetch_one(query) + if result: + # 将 request_msg (JSONB) 转换为字符串以便在 API 中返回 + log_dict = dict(result) + if 'request_msg' in log_dict and log_dict['request_msg'] is not None: + # 确保即使是 None 或非 JSON 数据也能处理 + try: + log_dict['request_msg'] = json.dumps(log_dict['request_msg'], ensure_ascii=False, indent=2) + except TypeError: + log_dict['request_msg'] = str(log_dict['request_msg']) + return log_dict + else: + return None + except Exception as e: + logger.exception(f"Failed to get error log details for ID {log_id}: {str(e)}") + raise + + +async def delete_error_logs_by_ids(log_ids: List[int]) -> int: + """ + 根据提供的 ID 列表批量删除错误日志 (异步)。 + + Args: + log_ids: 要删除的错误日志 ID 列表。 + + Returns: + int: 实际删除的日志数量。 + """ + if not log_ids: + return 0 
+ try: + # 使用 databases 执行删除 + query = delete(ErrorLog).where(ErrorLog.id.in_(log_ids)) + # execute 返回受影响的行数,但 databases 库的 execute 不直接返回 rowcount + # 我们需要先查询是否存在,或者依赖数据库约束/触发器(如果适用) + # 或者,我们可以执行删除并假设成功,除非抛出异常 + # 为了简单起见,我们执行删除并记录日志,不精确返回删除数量 + # 如果需要精确数量,需要先执行 SELECT COUNT(*) + await database.execute(query) + # 注意:databases 的 execute 不返回 rowcount,所以我们不能直接返回删除的数量 + # 返回 log_ids 的长度作为尝试删除的数量,或者返回 0/1 表示操作尝试 + logger.info(f"Attempted bulk deletion for error logs with IDs: {log_ids}") + return len(log_ids) # 返回尝试删除的数量 + except Exception as e: + # 数据库连接或执行错误 + logger.error(f"Error during bulk deletion of error logs {log_ids}: {e}", exc_info=True) + raise + +async def delete_error_log_by_id(log_id: int) -> bool: + """ + 根据 ID 删除单个错误日志 (异步)。 + + Args: + log_id: 要删除的错误日志 ID。 + + Returns: + bool: 如果成功删除返回 True,否则返回 False。 + """ + try: + # 先检查是否存在 (可选,但更明确) + check_query = select(ErrorLog.id).where(ErrorLog.id == log_id) + exists = await database.fetch_one(check_query) + + if not exists: + logger.warning(f"Attempted to delete non-existent error log with ID: {log_id}") + return False + + # 执行删除 + delete_query = delete(ErrorLog).where(ErrorLog.id == log_id) + await database.execute(delete_query) + logger.info(f"Successfully deleted error log with ID: {log_id}") + return True + except Exception as e: + logger.error(f"Error deleting error log with ID {log_id}: {e}", exc_info=True) + raise + + +async def delete_all_error_logs() -> int: + """ + 删除所有错误日志条目。 + + Returns: + int: 被删除的错误日志数量。 + """ + try: + # 1. 获取删除前的总数 + count_query = select(func.count()).select_from(ErrorLog) + total_to_delete = await database.fetch_val(count_query) + + if total_to_delete == 0: + logger.info("No error logs found to delete.") + return 0 + + # 2. 
执行删除操作 + delete_query = delete(ErrorLog) + await database.execute(delete_query) + + logger.info(f"Successfully deleted all {total_to_delete} error logs.") + return total_to_delete + except Exception as e: + logger.error(f"Failed to delete all error logs: {str(e)}", exc_info=True) + raise + + +# 新增函数:添加请求日志 +async def add_request_log( + model_name: Optional[str], + api_key: Optional[str], + is_success: bool, + status_code: Optional[int] = None, + latency_ms: Optional[int] = None, + request_time: Optional[datetime] = None +) -> bool: + """ + 添加 API 请求日志 + + Args: + model_name: 模型名称 + api_key: 使用的 API 密钥 + is_success: 请求是否成功 + status_code: API 响应状态码 + latency_ms: 请求耗时(毫秒) + request_time: 请求发生时间 (如果为 None, 则使用当前时间) + + Returns: + bool: 是否添加成功 + """ + try: + log_time = request_time if request_time else datetime.now() + + query = insert(RequestLog).values( + request_time=log_time, + model_name=model_name, + api_key=api_key, + is_success=is_success, + status_code=status_code, + latency_ms=latency_ms + ) + await database.execute(query) + return True + except Exception as e: + logger.error(f"Failed to add request log: {str(e)}") + return False diff --git a/app/domain/gemini_models.py b/app/domain/gemini_models.py new file mode 100644 index 0000000000000000000000000000000000000000..0efcac9e29bf21b9bf49119e6c34bae5f3846464 --- /dev/null +++ b/app/domain/gemini_models.py @@ -0,0 +1,79 @@ +from typing import Any, Dict, List, Literal, Optional, Union + +from pydantic import BaseModel, Field + +from app.core.constants import DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P + + +class SafetySetting(BaseModel): + category: Optional[ + Literal[ + "HARM_CATEGORY_HATE_SPEECH", + "HARM_CATEGORY_DANGEROUS_CONTENT", + "HARM_CATEGORY_HARASSMENT", + "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "HARM_CATEGORY_CIVIC_INTEGRITY", + ] + ] = None + threshold: Optional[ + Literal[ + "HARM_BLOCK_THRESHOLD_UNSPECIFIED", + "BLOCK_LOW_AND_ABOVE", + "BLOCK_MEDIUM_AND_ABOVE", + "BLOCK_ONLY_HIGH", + 
            "BLOCK_NONE",
            "OFF",
        ]
    ] = None


class GenerationConfig(BaseModel):
    # Generation parameters; field names use camelCase to match the Gemini
    # REST API's generationConfig object.
    stopSequences: Optional[List[str]] = None
    responseMimeType: Optional[str] = None
    responseSchema: Optional[Dict[str, Any]] = None
    candidateCount: Optional[int] = 1
    maxOutputTokens: Optional[int] = None
    temperature: Optional[float] = DEFAULT_TEMPERATURE
    topP: Optional[float] = DEFAULT_TOP_P
    topK: Optional[int] = DEFAULT_TOP_K
    presencePenalty: Optional[float] = None
    frequencyPenalty: Optional[float] = None
    responseLogprobs: Optional[bool] = None
    logprobs: Optional[int] = None
    thinkingConfig: Optional[Dict[str, Any]] = None


class SystemInstruction(BaseModel):
    # System prompt; `parts` accepts either a single part dict or a list.
    role: Optional[str] = "system"
    parts: Union[List[Dict[str, Any]], Dict[str, Any]]


class GeminiContent(BaseModel):
    # One conversation turn: role plus a list of content parts.
    role: Optional[str] = None
    parts: List[Dict[str, Any]]


class GeminiRequest(BaseModel):
    """Request body for the Gemini generateContent endpoint.

    Aliases allow clients to send either camelCase or snake_case field names
    (populate_by_name below enables the camelCase attribute names too).
    """
    contents: List[GeminiContent] = []
    tools: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = []
    safetySettings: Optional[List[SafetySetting]] = Field(
        default=None, alias="safety_settings"
    )
    generationConfig: Optional[GenerationConfig] = Field(
        default=None, alias="generation_config"
    )
    systemInstruction: Optional[SystemInstruction] = Field(
        default=None, alias="system_instruction"
    )

    class Config:
        populate_by_name = True


class ResetSelectedKeysRequest(BaseModel):
    # Admin request: reset failure counters for the given keys.
    # `key_type` semantics are defined by the consuming route — TODO confirm.
    keys: List[str]
    key_type: str


class VerifySelectedKeysRequest(BaseModel):
    # Admin request: re-verify the given API keys.
    keys: List[str]


# --- app/domain/image_models.py ---
from typing import Union


class ImageMetadata:
    """Metadata describing one uploaded image (dimensions, location, size)."""
    def __init__(self, width: int, height: int, filename: str, size: int, url: str, delete_url: Union[str, None] = None):
        self.width = width
        self.height = height
        self.filename = filename
        self.size = size
        self.url = url
        # Optional provider-supplied URL for deleting the upload later.
        self.delete_url = delete_url


class UploadResponse:
    """Result envelope returned by an image uploader implementation."""
    def __init__(self, success: bool, code: str, message: str, data: ImageMetadata):
        self.success = success
        self.code = code
        self.message = message
        self.data = data


class ImageUploader:
    """Abstract uploader interface; concrete providers implement upload()."""
    def upload(self, file: bytes, filename: str) -> UploadResponse:
        raise NotImplementedError


# --- app/domain/openai_models.py ---
from pydantic import BaseModel
from typing import Any, Dict, List, Optional, Union

from app.core.constants import DEFAULT_MODEL, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P


class ChatRequest(BaseModel):
    """OpenAI-compatible chat completion request body."""
    messages: List[dict]
    model: str = DEFAULT_MODEL
    temperature: Optional[float] = DEFAULT_TEMPERATURE
    stream: Optional[bool] = False
    max_tokens: Optional[int] = None
    top_p: Optional[float] = DEFAULT_TOP_P
    top_k: Optional[int] = DEFAULT_TOP_K
    stop: Optional[Union[List[str], str]] = None
    reasoning_effort: Optional[str] = None
    tools: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = []
    tool_choice: Optional[str] = None
    response_format: Optional[dict] = None


class EmbeddingRequest(BaseModel):
    """OpenAI-compatible embeddings request body."""
    input: Union[str, List[str]]
    model: str = "text-embedding-004"
    encoding_format: Optional[str] = "float"


class ImageGenerationRequest(BaseModel):
    """OpenAI-compatible image generation request body."""
    model: str = "imagen-3.0-generate-002"
    prompt: str = ""
    n: int = 1
    size: Optional[str] = "1024x1024"
    quality: Optional[str] = None
    style: Optional[str] = None
    response_format: Optional[str] = "url"


class TTSRequest(BaseModel):
    """OpenAI-compatible text-to-speech request body."""
    model: str = "gemini-2.5-flash-preview-tts"
    input: str
    voice: str = "Kore"
    response_format: Optional[str] = "wav"
0000000000000000000000000000000000000000..0e9fb307c62aea51ebbfa742ab02e4f5088d5cdb --- /dev/null +++ b/app/exception/exceptions.py @@ -0,0 +1,140 @@ +""" +异常处理模块,定义应用程序中使用的自定义异常和异常处理器 +""" + +from fastapi import FastAPI, Request +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse +from starlette.exceptions import HTTPException as StarletteHTTPException + +from app.log.logger import get_exceptions_logger + +logger = get_exceptions_logger() + + +class APIError(Exception): + """API错误基类""" + + def __init__(self, status_code: int, detail: str, error_code: str = None): + self.status_code = status_code + self.detail = detail + self.error_code = error_code or "api_error" + super().__init__(self.detail) + + +class AuthenticationError(APIError): + """认证错误""" + + def __init__(self, detail: str = "Authentication failed"): + super().__init__( + status_code=401, detail=detail, error_code="authentication_error" + ) + + +class AuthorizationError(APIError): + """授权错误""" + + def __init__(self, detail: str = "Not authorized to access this resource"): + super().__init__( + status_code=403, detail=detail, error_code="authorization_error" + ) + + +class ResourceNotFoundError(APIError): + """资源未找到错误""" + + def __init__(self, detail: str = "Resource not found"): + super().__init__( + status_code=404, detail=detail, error_code="resource_not_found" + ) + + +class ModelNotSupportedError(APIError): + """模型不支持错误""" + + def __init__(self, model: str): + super().__init__( + status_code=400, + detail=f"Model {model} is not supported", + error_code="model_not_supported", + ) + + +class APIKeyError(APIError): + """API密钥错误""" + + def __init__(self, detail: str = "Invalid or expired API key"): + super().__init__(status_code=401, detail=detail, error_code="api_key_error") + + +class ServiceUnavailableError(APIError): + """服务不可用错误""" + + def __init__(self, detail: str = "Service temporarily unavailable"): + super().__init__( + status_code=503, 
detail=detail, error_code="service_unavailable" + ) + + +def setup_exception_handlers(app: FastAPI) -> None: + """ + 设置应用程序的异常处理器 + + Args: + app: FastAPI应用程序实例 + """ + + @app.exception_handler(APIError) + async def api_error_handler(request: Request, exc: APIError): + """处理API错误""" + logger.error(f"API Error: {exc.detail} (Code: {exc.error_code})") + return JSONResponse( + status_code=exc.status_code, + content={"error": {"code": exc.error_code, "message": exc.detail}}, + ) + + @app.exception_handler(StarletteHTTPException) + async def http_exception_handler(request: Request, exc: StarletteHTTPException): + """处理HTTP异常""" + logger.error(f"HTTP Exception: {exc.detail} (Status: {exc.status_code})") + return JSONResponse( + status_code=exc.status_code, + content={"error": {"code": "http_error", "message": exc.detail}}, + ) + + @app.exception_handler(RequestValidationError) + async def validation_exception_handler( + request: Request, exc: RequestValidationError + ): + """处理请求验证错误""" + error_details = [] + for error in exc.errors(): + error_details.append( + {"loc": error["loc"], "msg": error["msg"], "type": error["type"]} + ) + + logger.error(f"Validation Error: {error_details}") + return JSONResponse( + status_code=422, + content={ + "error": { + "code": "validation_error", + "message": "Request validation failed", + "details": error_details, + } + }, + ) + + @app.exception_handler(Exception) + async def general_exception_handler(request: Request, exc: Exception): + """处理通用异常""" + logger.exception(f"Unhandled Exception: {str(exc)}") + return JSONResponse( + status_code=500, + content={ + "error": { + "code": "internal_server_error", + "message": "An unexpected error occurred", + } + }, + ) diff --git a/app/handler/error_handler.py b/app/handler/error_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..b7ab6d88521a145e0dd451ac7724d98769cffca2 --- /dev/null +++ b/app/handler/error_handler.py @@ -0,0 +1,32 @@ +from contextlib import 
asynccontextmanager +from fastapi import HTTPException +import logging + +@asynccontextmanager +async def handle_route_errors(logger: logging.Logger, operation_name: str, success_message: str = None, failure_message: str = None): + """ + 一个异步上下文管理器,用于统一处理 FastAPI 路由中的常见错误和日志记录。 + + Args: + logger: 用于记录日志的 Logger 实例。 + operation_name: 操作的名称,用于日志记录和错误详情。 + success_message: 操作成功时记录的自定义消息 (可选)。 + failure_message: 操作失败时记录的自定义消息 (可选)。 + """ + default_success_msg = f"{operation_name} request successful" + default_failure_msg = f"{operation_name} request failed" + + logger.info("-" * 50 + operation_name + "-" * 50) + try: + yield + logger.info(success_message or default_success_msg) + except HTTPException as http_exc: + # 如果已经是 HTTPException,直接重新抛出,保留原始状态码和详情 + logger.error(f"{failure_message or default_failure_msg}: {http_exc.detail} (Status: {http_exc.status_code})") + raise http_exc + except Exception as e: + # 对于其他所有异常,记录错误并抛出标准的 500 错误 + logger.error(f"{failure_message or default_failure_msg}: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Internal server error during {operation_name}" + ) from e \ No newline at end of file diff --git a/app/handler/message_converter.py b/app/handler/message_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..378871a2137ad82728cba13c0ae574cadcca0014 --- /dev/null +++ b/app/handler/message_converter.py @@ -0,0 +1,349 @@ +import base64 +import json +import re +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +import requests + +from app.core.constants import ( + AUDIO_FORMAT_TO_MIMETYPE, + DATA_URL_PATTERN, + IMAGE_URL_PATTERN, + MAX_AUDIO_SIZE_BYTES, + MAX_VIDEO_SIZE_BYTES, + SUPPORTED_AUDIO_FORMATS, + SUPPORTED_ROLES, + SUPPORTED_VIDEO_FORMATS, + VIDEO_FORMAT_TO_MIMETYPE, +) +from app.log.logger import get_message_converter_logger + +logger = get_message_converter_logger() + + +class MessageConverter(ABC): + """消息转换器基类""" + + @abstractmethod + def convert( 
+ self, messages: List[Dict[str, Any]] + ) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]: + pass + + +def _get_mime_type_and_data(base64_string): + """ + 从 base64 字符串中提取 MIME 类型和数据。 + + 参数: + base64_string (str): 可能包含 MIME 类型信息的 base64 字符串 + + 返回: + tuple: (mime_type, encoded_data) + """ + # 检查字符串是否以 "data:" 格式开始 + if base64_string.startswith("data:"): + # 提取 MIME 类型和数据 + pattern = DATA_URL_PATTERN + match = re.match(pattern, base64_string) + if match: + mime_type = ( + "image/jpeg" if match.group(1) == "image/jpg" else match.group(1) + ) + encoded_data = match.group(2) + return mime_type, encoded_data + + # 如果不是预期格式,假定它只是数据部分 + return None, base64_string + + +def _convert_image(image_url: str) -> Dict[str, Any]: + if image_url.startswith("data:image"): + mime_type, encoded_data = _get_mime_type_and_data(image_url) + return {"inline_data": {"mime_type": mime_type, "data": encoded_data}} + else: + encoded_data = _convert_image_to_base64(image_url) + return {"inline_data": {"mime_type": "image/png", "data": encoded_data}} + + +def _convert_image_to_base64(url: str) -> str: + """ + 将图片URL转换为base64编码 + Args: + url: 图片URL + Returns: + str: base64编码的图片数据 + """ + response = requests.get(url) + if response.status_code == 200: + # 将图片内容转换为base64 + img_data = base64.b64encode(response.content).decode("utf-8") + return img_data + else: + raise Exception(f"Failed to fetch image: {response.status_code}") + + +def _process_text_with_image(text: str) -> List[Dict[str, Any]]: + """ + 处理可能包含图片URL的文本,提取图片并转换为base64 + + Args: + text: 可能包含图片URL的文本 + + Returns: + List[Dict[str, Any]]: 包含文本和图片的部分列表 + """ + parts = [] + img_url_match = re.search(IMAGE_URL_PATTERN, text) + if img_url_match: + # 提取URL + img_url = img_url_match.group(2) + # 将URL对应的图片转换为base64 + try: + base64_data = _convert_image_to_base64(img_url) + parts.append( + {"inline_data": {"mimeType": "image/png", "data": base64_data}} + ) + except Exception: + # 如果转换失败,回退到文本模式 + parts.append({"text": text}) + else: + # 
没有图片URL,作为纯文本处理 + parts.append({"text": text}) + return parts + + +class OpenAIMessageConverter(MessageConverter): + """OpenAI消息格式转换器""" + + def _validate_media_data( + self, format: str, data: str, supported_formats: List[str], max_size: int + ) -> tuple[Optional[str], Optional[str]]: + """Validates format and size of Base64 media data.""" + if format.lower() not in supported_formats: + logger.error( + f"Unsupported media format: {format}. Supported: {supported_formats}" + ) + raise ValueError(f"Unsupported media format: {format}") + + try: + decoded_data = base64.b64decode(data, validate=True) + if len(decoded_data) > max_size: + logger.error( + f"Media data size ({len(decoded_data)} bytes) exceeds limit ({max_size} bytes)." + ) + raise ValueError( + f"Media data size exceeds limit of {max_size // 1024 // 1024}MB" + ) + return data + except base64.binascii.Error as e: + logger.error(f"Invalid Base64 data provided: {e}") + raise ValueError("Invalid Base64 data") + except Exception as e: + logger.error(f"Error validating media data: {e}") + raise + + def convert( + self, messages: List[Dict[str, Any]] + ) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]: + converted_messages = [] + system_instruction_parts = [] + + for idx, msg in enumerate(messages): + role = msg.get("role", "") + parts = [] + + if "content" in msg and isinstance(msg["content"], list): + for content_item in msg["content"]: + if not isinstance(content_item, dict): + logger.warning( + f"Skipping unexpected content item format: {type(content_item)}" + ) + continue + + content_type = content_item.get("type") + + if content_type == "text" and content_item.get("text"): + parts.append({"text": content_item["text"]}) + elif content_type == "image_url" and content_item.get( + "image_url", {} + ).get("url"): + try: + parts.append( + _convert_image(content_item["image_url"]["url"]) + ) + except Exception as e: + logger.error( + f"Failed to convert image URL {content_item['image_url']['url']}: {e}" + 
) + parts.append( + { + "text": f"[Error processing image: {content_item['image_url']['url']}]" + } + ) + elif content_type == "input_audio" and content_item.get( + "input_audio" + ): + audio_info = content_item["input_audio"] + audio_data = audio_info.get("data") + audio_format = audio_info.get("format", "").lower() + + if not audio_data or not audio_format: + logger.warning( + "Skipping audio part due to missing data or format." + ) + continue + + try: + validated_data = self._validate_media_data( + audio_format, + audio_data, + SUPPORTED_AUDIO_FORMATS, + MAX_AUDIO_SIZE_BYTES, + ) + + # Get MIME type + mime_type = AUDIO_FORMAT_TO_MIMETYPE.get(audio_format) + if not mime_type: + # Should not happen if format validation passed, but double-check + logger.error( + f"Could not find MIME type for supported format: {audio_format}" + ) + raise ValueError( + f"Internal error: MIME type mapping missing for {audio_format}" + ) + + parts.append( + { + "inline_data": { + "mimeType": mime_type, + "data": validated_data, # Use the validated Base64 data + } + } + ) + logger.debug( + f"Successfully added audio part (format: {audio_format})" + ) + + except ValueError as e: + logger.error( + f"Skipping audio part due to validation error: {e}" + ) + parts.append({"text": f"[Error processing audio: {e}]"}) + except Exception: + logger.exception("Unexpected error processing audio part.") + parts.append( + {"text": "[Unexpected error processing audio]"} + ) + + elif content_type == "input_video" and content_item.get( + "input_video" + ): + video_info = content_item["input_video"] + video_data = video_info.get("data") + video_format = video_info.get("format", "").lower() + + if not video_data or not video_format: + logger.warning( + "Skipping video part due to missing data or format." 
+ ) + continue + + try: + validated_data = self._validate_media_data( + video_format, + video_data, + SUPPORTED_VIDEO_FORMATS, + MAX_VIDEO_SIZE_BYTES, + ) + mime_type = VIDEO_FORMAT_TO_MIMETYPE.get(video_format) + if not mime_type: + raise ValueError( + f"Internal error: MIME type mapping missing for {video_format}" + ) + + parts.append( + { + "inline_data": { + "mimeType": mime_type, + "data": validated_data, + } + } + ) + logger.debug( + f"Successfully added video part (format: {video_format})" + ) + + except ValueError as e: + logger.error( + f"Skipping video part due to validation error: {e}" + ) + parts.append({"text": f"[Error processing video: {e}]"}) + except Exception: + logger.exception("Unexpected error processing video part.") + parts.append( + {"text": "[Unexpected error processing video]"} + ) + + else: + # Log unrecognized but present types + if content_type: + logger.warning( + f"Unsupported content type or missing data in structured content: {content_type}" + ) + + elif ( + "content" in msg and isinstance(msg["content"], str) and msg["content"] + ): + parts.extend(_process_text_with_image(msg["content"])) + elif "tool_calls" in msg and isinstance(msg["tool_calls"], list): + # Keep existing tool call processing + for tool_call in msg["tool_calls"]: + function_call = tool_call.get("function", {}) + # Sanitize arguments loading + arguments_str = function_call.get("arguments", "{}") + try: + function_call["args"] = json.loads(arguments_str) + except json.JSONDecodeError: + logger.warning( + f"Failed to decode tool call arguments: {arguments_str}" + ) + function_call["args"] = {} + if "arguments" in function_call: + if "arguments" in function_call: + del function_call["arguments"] + + parts.append({"functionCall": function_call}) + + if role not in SUPPORTED_ROLES: + if role == "tool": + role = "user" + else: + # 如果是最后一条消息,则认为是用户消息 + if idx == len(messages) - 1: + role = "user" + else: + role = "model" + if parts: + if role == "system": + 
text_only_parts = [p for p in parts if "text" in p] + if len(text_only_parts) != len(parts): + logger.warning( + "Non-text parts found in system message; discarding them." + ) + if text_only_parts: + system_instruction_parts.extend(text_only_parts) + + else: + converted_messages.append({"role": role, "parts": parts}) + + system_instruction = ( + None + if not system_instruction_parts + else { + "role": "system", + "parts": system_instruction_parts, + } + ) + return converted_messages, system_instruction diff --git a/app/handler/response_handler.py b/app/handler/response_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f61f11a3b491bbf91c60d1f07e97b0576baac9 --- /dev/null +++ b/app/handler/response_handler.py @@ -0,0 +1,360 @@ +import base64 +import json +import random +import string +import time +import uuid +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +from app.config.config import settings +from app.utils.uploader import ImageUploaderFactory + + +class ResponseHandler(ABC): + """响应处理器基类""" + + @abstractmethod + def handle_response( + self, response: Dict[str, Any], model: str, stream: bool = False + ) -> Dict[str, Any]: + pass + + +class GeminiResponseHandler(ResponseHandler): + """Gemini响应处理器""" + + def __init__(self): + self.thinking_first = True + self.thinking_status = False + + def handle_response( + self, response: Dict[str, Any], model: str, stream: bool = False, usage_metadata: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + if stream: + return _handle_gemini_stream_response(response, model, stream) + return _handle_gemini_normal_response(response, model, stream) + + +def _handle_openai_stream_response( + response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]] +) -> Dict[str, Any]: + text, tool_calls, _ = _extract_result( + response, model, stream=True, gemini_format=False + ) + if not text and not tool_calls: + delta = {} + else: + 
delta = {"content": text, "role": "assistant"} + if tool_calls: + delta["tool_calls"] = tool_calls + template_chunk = { + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}], + } + if usage_metadata: + template_chunk["usage"] = {"prompt_tokens": usage_metadata.get("promptTokenCount", 0), "completion_tokens": usage_metadata.get("candidatesTokenCount",0), "total_tokens": usage_metadata.get("totalTokenCount", 0)} + return template_chunk + + +def _handle_openai_normal_response( + response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]] +) -> Dict[str, Any]: + text, tool_calls, _ = _extract_result( + response, model, stream=False, gemini_format=False + ) + return { + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": text, + "tool_calls": tool_calls, + }, + "finish_reason": finish_reason, + } + ], + "usage": {"prompt_tokens": usage_metadata.get("promptTokenCount", 0), "completion_tokens": usage_metadata.get("candidatesTokenCount",0), "total_tokens": usage_metadata.get("totalTokenCount", 0)}, + } + + +class OpenAIResponseHandler(ResponseHandler): + """OpenAI响应处理器""" + + def __init__(self, config): + self.config = config + self.thinking_first = True + self.thinking_status = False + + def handle_response( + self, + response: Dict[str, Any], + model: str, + stream: bool = False, + finish_reason: str = None, + usage_metadata: Optional[Dict[str, Any]] = None, + ) -> Optional[Dict[str, Any]]: + if stream: + return _handle_openai_stream_response(response, model, finish_reason, usage_metadata) + return _handle_openai_normal_response(response, model, finish_reason, usage_metadata) + + def handle_image_chat_response( + self, image_str: str, model: str, 
stream=False, finish_reason="stop" + ): + if stream: + return _handle_openai_stream_image_response(image_str, model, finish_reason) + return _handle_openai_normal_image_response(image_str, model, finish_reason) + + +def _handle_openai_stream_image_response( + image_str: str, model: str, finish_reason: str +) -> Dict[str, Any]: + return { + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"content": image_str} if image_str else {}, + "finish_reason": finish_reason, + } + ], + } + + +def _handle_openai_normal_image_response( + image_str: str, model: str, finish_reason: str +) -> Dict[str, Any]: + return { + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": image_str}, + "finish_reason": finish_reason, + } + ], + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + } + + +def _extract_result( + response: Dict[str, Any], + model: str, + stream: bool = False, + gemini_format: bool = False, +) -> tuple[str, List[Dict[str, Any]], Optional[bool]]: + text, tool_calls = "", [] + thought = None + if stream: + if response.get("candidates"): + candidate = response["candidates"][0] + content = candidate.get("content", {}) + parts = content.get("parts", []) + if not parts: + return "", [], None + if "text" in parts[0]: + text = parts[0].get("text") + if "thought" in parts[0]: + thought = parts[0].get("thought") + elif "executableCode" in parts[0]: + text = _format_code_block(parts[0]["executableCode"]) + elif "codeExecution" in parts[0]: + text = _format_code_block(parts[0]["codeExecution"]) + elif "executableCodeResult" in parts[0]: + text = _format_execution_result(parts[0]["executableCodeResult"]) + elif "codeExecutionResult" in parts[0]: + text = 
_format_execution_result(parts[0]["codeExecutionResult"]) + elif "inlineData" in parts[0]: + text = _extract_image_data(parts[0]) + else: + text = "" + text = _add_search_link_text(model, candidate, text) + tool_calls = _extract_tool_calls(parts, gemini_format) + else: + if response.get("candidates"): + candidate = response["candidates"][0] + if "thinking" in model: + if settings.SHOW_THINKING_PROCESS: + if len(candidate["content"]["parts"]) == 2: + text = ( + "> thinking\n\n" + + candidate["content"]["parts"][0]["text"] + + "\n\n---\n> output\n\n" + + candidate["content"]["parts"][1]["text"] + ) + else: + text = candidate["content"]["parts"][0]["text"] + else: + if len(candidate["content"]["parts"]) == 2: + text = candidate["content"]["parts"][1]["text"] + else: + text = candidate["content"]["parts"][0]["text"] + else: + text = "" + if "parts" in candidate["content"]: + for part in candidate["content"]["parts"]: + if "text" in part: + text += part["text"] + if "thought" in part and thought is None: + thought = part.get("thought") + elif "inlineData" in part: + text += _extract_image_data(part) + + text = _add_search_link_text(model, candidate, text) + tool_calls = _extract_tool_calls( + candidate["content"]["parts"], gemini_format + ) + else: + text = "暂无返回" + return text, tool_calls, thought + + +def _extract_image_data(part: dict) -> str: + image_uploader = None + if settings.UPLOAD_PROVIDER == "smms": + image_uploader = ImageUploaderFactory.create( + provider=settings.UPLOAD_PROVIDER, api_key=settings.SMMS_SECRET_TOKEN + ) + elif settings.UPLOAD_PROVIDER == "picgo": + image_uploader = ImageUploaderFactory.create( + provider=settings.UPLOAD_PROVIDER, api_key=settings.PICGO_API_KEY + ) + elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed": + image_uploader = ImageUploaderFactory.create( + provider=settings.UPLOAD_PROVIDER, + base_url=settings.CLOUDFLARE_IMGBED_URL, + auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE, + ) + current_date = time.strftime("%Y/%m/%d") + 
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png" + base64_data = part["inlineData"]["data"] + # 将base64_data转成bytes数组 + bytes_data = base64.b64decode(base64_data) + upload_response = image_uploader.upload(bytes_data, filename) + if upload_response.success: + text = f"\n\n![image]({upload_response.data.url})\n\n" + else: + text = "" + return text + + +def _extract_tool_calls( + parts: List[Dict[str, Any]], gemini_format: bool +) -> List[Dict[str, Any]]: + """提取工具调用信息""" + if not parts or not isinstance(parts, list): + return [] + + letters = string.ascii_lowercase + string.digits + + tool_calls = list() + for i in range(len(parts)): + part = parts[i] + if not part or not isinstance(part, dict): + continue + + item = part.get("functionCall", {}) + if not item or not isinstance(item, dict): + continue + + if gemini_format: + tool_calls.append(part) + else: + id = f"call_{''.join(random.sample(letters, 32))}" + name = item.get("name", "") + arguments = json.dumps(item.get("args", None) or {}) + + tool_calls.append( + { + "index": i, + "id": id, + "type": "function", + "function": {"name": name, "arguments": arguments}, + } + ) + + return tool_calls + + +def _handle_gemini_stream_response( + response: Dict[str, Any], model: str, stream: bool +) -> Dict[str, Any]: + text, tool_calls, thought = _extract_result( + response, model, stream=stream, gemini_format=True + ) + if tool_calls: + content = {"parts": tool_calls, "role": "model"} + else: + part = {"text": text} + if thought is not None: + part["thought"] = thought + content = {"parts": [part], "role": "model"} + response["candidates"][0]["content"] = content + return response + + +def _handle_gemini_normal_response( + response: Dict[str, Any], model: str, stream: bool +) -> Dict[str, Any]: + text, tool_calls, thought = _extract_result( + response, model, stream=stream, gemini_format=True + ) + if tool_calls: + content = {"parts": tool_calls, "role": "model"} + else: + part = {"text": text} + if thought is not 
None: + part["thought"] = thought + content = {"parts": [part], "role": "model"} + response["candidates"][0]["content"] = content + return response + + +def _format_code_block(code_data: dict) -> str: + """格式化代码块输出""" + language = code_data.get("language", "").lower() + code = code_data.get("code", "").strip() + return f"""\n\n---\n\n【代码执行】\n```{language}\n{code}\n```\n""" + + +def _add_search_link_text(model: str, candidate: dict, text: str) -> str: + if ( + settings.SHOW_SEARCH_LINK + and model.endswith("-search") + and "groundingMetadata" in candidate + and "groundingChunks" in candidate["groundingMetadata"] + ): + grounding_chunks = candidate["groundingMetadata"]["groundingChunks"] + text += "\n\n---\n\n" + text += "**【引用来源】**\n\n" + for _, grounding_chunk in enumerate(grounding_chunks, 1): + if "web" in grounding_chunk: + text += _create_search_link(grounding_chunk["web"]) + return text + else: + return text + + +def _create_search_link(grounding_chunk: dict) -> str: + return f'\n- [{grounding_chunk["title"]}]({grounding_chunk["uri"]})' + + +def _format_execution_result(result_data: dict) -> str: + """格式化执行结果输出""" + outcome = result_data.get("outcome", "") + output = result_data.get("output", "").strip() + return f"""\n【执行结果】\n> outcome: {outcome}\n\n【输出结果】\n```plaintext\n{output}\n```\n\n---\n\n""" diff --git a/app/handler/retry_handler.py b/app/handler/retry_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..0161e27d2d3b7701393ca06b091504e03fe03ada --- /dev/null +++ b/app/handler/retry_handler.py @@ -0,0 +1,50 @@ + +from functools import wraps +from typing import Callable, TypeVar + +from app.config.config import settings +from app.log.logger import get_retry_logger + +T = TypeVar("T") +logger = get_retry_logger() + + +class RetryHandler: + """重试处理装饰器""" + + def __init__(self, key_arg: str = "api_key"): + self.key_arg = key_arg + + def __call__(self, func: Callable[..., T]) -> Callable[..., T]: + @wraps(func) + async def 
wrapper(*args, **kwargs) -> T:
            last_exception = None

            for attempt in range(settings.MAX_RETRIES):
                retries = attempt + 1
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    last_exception = e
                    logger.warning(
                        f"API call failed with error: {str(e)}. Attempt {retries} of {settings.MAX_RETRIES}"
                    )

                    # Pull the key manager out of the call's keyword arguments
                    key_manager = kwargs.get("key_manager")
                    if key_manager:
                        old_key = kwargs.get(self.key_arg)
                        # Report the failure and rotate to a replacement key
                        new_key = await key_manager.handle_api_failure(old_key, retries)
                        if new_key:
                            kwargs[self.key_arg] = new_key
                            logger.info(f"Switched to new API key: {new_key}")
                        else:
                            # No key left to rotate to: stop retrying early
                            logger.error(f"No valid API key available after {retries} retries.")
                            break

            logger.error(
                f"All retry attempts failed, raising final exception: {str(last_exception)}"
            )
            raise last_exception

        return wrapper
diff --git a/app/handler/stream_optimizer.py b/app/handler/stream_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c27ff20d715072873479685a20acd2fac6f2db60
--- /dev/null
+++ b/app/handler/stream_optimizer.py
@@ -0,0 +1,143 @@

import asyncio
import math
from typing import Any, AsyncGenerator, Callable, List

from app.config.config import settings
from app.core.constants import (
    DEFAULT_STREAM_CHUNK_SIZE,
    DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
    DEFAULT_STREAM_MAX_DELAY,
    DEFAULT_STREAM_MIN_DELAY,
    DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
)
from app.log.logger import get_gemini_logger, get_openai_logger

logger_openai = get_openai_logger()
logger_gemini = get_gemini_logger()


class StreamOptimizer:
    """Streaming-output optimizer.

    Provides streaming-output optimizations: adaptive per-chunk delay
    and chunked emission of long texts.
    """

    def __init__(
        self,
        logger=None,
        min_delay: float = DEFAULT_STREAM_MIN_DELAY,
        max_delay: float = DEFAULT_STREAM_MAX_DELAY,
        short_text_threshold: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
        long_text_threshold: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
        chunk_size: int = DEFAULT_STREAM_CHUNK_SIZE,
    ):
        """Initialize the stream optimizer.

        Args:
            logger: logger instance
            min_delay: minimum delay between emissions (seconds)
            max_delay: maximum delay between emissions (seconds)
            short_text_threshold: short-text threshold (characters)
            long_text_threshold: long-text threshold (characters)
            chunk_size: chunk size used for long texts (characters)
        """
        self.logger = logger
        self.min_delay = min_delay
        self.max_delay = max_delay
        self.short_text_threshold = short_text_threshold
        self.long_text_threshold = long_text_threshold
        self.chunk_size = chunk_size

    def calculate_delay(self, text_length: int) -> float:
        """Compute the emission delay from the text length.

        Args:
            text_length: length of the text

        Returns:
            Delay in seconds.
        """
        if text_length <= self.short_text_threshold:
            # Short texts get the largest delay
            return self.max_delay
        elif text_length >= self.long_text_threshold:
            # Long texts get the smallest delay
            return self.min_delay
        else:
            # Mid-length texts: interpolate between min and max delay.
            # A logarithmic scale makes the transition smoother than linear.
            ratio = math.log(text_length / self.short_text_threshold) / math.log(
                self.long_text_threshold / self.short_text_threshold
            )
            return self.max_delay - ratio * (self.max_delay - self.min_delay)

    def split_text_into_chunks(self, text: str) -> List[str]:
        """Split text into fixed-size chunks.

        Args:
            text: text to split

        Returns:
            List of text chunks.
        """
        return [
            text[i : i + self.chunk_size] for i in range(0, len(text), self.chunk_size)
        ]

    async def optimize_stream_output(
        self,
        text: str,
        create_response_chunk: Callable[[str], Any],
        format_chunk: Callable[[Any], str],
    ) -> AsyncGenerator[str, None]:
        """Optimize streaming output.

        Args:
            text: the text to emit
            create_response_chunk: builds a response chunk from a piece of text
            format_chunk: formats a response chunk into the wire string

        Returns:
            Async generator yielding formatted response chunks.
        """
        if not text:
            return

        # Compute the adaptive delay once for the whole text
        delay = self.calculate_delay(len(text))

        # Choose the emission strategy based on text length
        if len(text) >= self.long_text_threshold:
            # Long text: emit chunk by chunk
            chunks = self.split_text_into_chunks(text)
            for chunk_text in chunks:
                chunk_response = create_response_chunk(chunk_text)
                yield format_chunk(chunk_response)
                await asyncio.sleep(delay)
        else:
            # Short text: emit character by character
            for char in text:
                char_chunk = create_response_chunk(char)
                yield format_chunk(char_chunk)
                await asyncio.sleep(delay)


# Default optimizer instances, importable directly
+openai_optimizer = StreamOptimizer( + logger=logger_openai, + min_delay=settings.STREAM_MIN_DELAY, + max_delay=settings.STREAM_MAX_DELAY, + short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD, + long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD, + chunk_size=settings.STREAM_CHUNK_SIZE, +) + +gemini_optimizer = StreamOptimizer( + logger=logger_gemini, + min_delay=settings.STREAM_MIN_DELAY, + max_delay=settings.STREAM_MAX_DELAY, + short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD, + long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD, + chunk_size=settings.STREAM_CHUNK_SIZE, +) diff --git a/app/log/logger.py b/app/log/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..1614a4664869dbb9c377b937656812dc3efd7e07 --- /dev/null +++ b/app/log/logger.py @@ -0,0 +1,233 @@ +import logging +import platform +import sys +from typing import Dict, Optional + +# ANSI转义序列颜色代码 +COLORS = { + "DEBUG": "\033[34m", # 蓝色 + "INFO": "\033[32m", # 绿色 + "WARNING": "\033[33m", # 黄色 + "ERROR": "\033[31m", # 红色 + "CRITICAL": "\033[1;31m", # 红色加粗 +} + +# Windows系统启用ANSI支持 +if platform.system() == "Windows": + import ctypes + + kernel32 = ctypes.windll.kernel32 + kernel32.SetConsoleMode(kernel32.GetStdHandle(-11), 7) + + +class ColoredFormatter(logging.Formatter): + """ + 自定义的日志格式化器,添加颜色支持 + """ + + def format(self, record): + # 获取对应级别的颜色代码 + color = COLORS.get(record.levelname, "") + # 添加颜色代码和重置代码 + record.levelname = f"{color}{record.levelname}\033[0m" + # 创建包含文件名和行号的固定宽度字符串 + record.fileloc = f"[{record.filename}:{record.lineno}]" + return super().format(record) + + +# 日志格式 - 使用 fileloc 并设置固定宽度 (例如 30) +FORMATTER = ColoredFormatter( + "%(asctime)s | %(levelname)-17s | %(fileloc)-30s | %(message)s" +) + +# 日志级别映射 +LOG_LEVELS = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, +} + + +class Logger: + def __init__(self): + pass + + _loggers: Dict[str, 
logging.Logger] = {}

    @staticmethod
    def setup_logger(name: str) -> logging.Logger:
        """Create (or fetch and level-sync) the logger registered under *name*.

        :param name: logger name
        :return: logger instance
        """
        # Imported here to avoid a circular import with the config module
        from app.config.config import settings

        # Resolve the level from global configuration
        log_level_str = settings.LOG_LEVEL.lower()
        level = LOG_LEVELS.get(log_level_str, logging.INFO)

        if name in Logger._loggers:
            # Logger already exists: sync its level with the current config
            existing_logger = Logger._loggers[name]
            if existing_logger.level != level:
                existing_logger.setLevel(level)
            return existing_logger

        logger = logging.getLogger(name)
        logger.setLevel(level)
        # Do not bubble records up to the root logger (avoids duplicate output)
        logger.propagate = False

        # Attach console output
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(FORMATTER)
        logger.addHandler(console_handler)

        Logger._loggers[name] = logger
        return logger

    @staticmethod
    def get_logger(name: str) -> Optional[logging.Logger]:
        """Return an already-created logger, or None.

        :param name: logger name
        :return: logger instance or None
        """
        return Logger._loggers.get(name)

    @staticmethod
    def update_log_levels(log_level: str):
        """Apply *log_level* to every logger created so far."""
        log_level_str = log_level.lower()
        new_level = LOG_LEVELS.get(log_level_str, logging.INFO)

        updated_count = 0
        for logger_name, logger_instance in Logger._loggers.items():
            if logger_instance.level != new_level:
                logger_instance.setLevel(new_level)
                # Optional: log the level change, but avoid generating
                # excessive output from inside the logging module itself
                # print(f"Updated log level for logger '{logger_name}' to {log_level_str.upper()}")
                updated_count += 1


# Predefined loggers
def get_openai_logger():
    return Logger.setup_logger("openai")


def get_gemini_logger():
    return Logger.setup_logger("gemini")


def get_chat_logger():
    return Logger.setup_logger("chat")


def get_model_logger():
    return Logger.setup_logger("model")


def get_security_logger():
    return Logger.setup_logger("security")


def get_key_manager_logger():
    return Logger.setup_logger("key_manager")


def get_main_logger():
    return 
Logger.setup_logger("main") + + +def get_embeddings_logger(): + return Logger.setup_logger("embeddings") + + +def get_request_logger(): + return Logger.setup_logger("request") + + +def get_retry_logger(): + return Logger.setup_logger("retry") + + +def get_image_create_logger(): + return Logger.setup_logger("image_create") + + +def get_exceptions_logger(): + return Logger.setup_logger("exceptions") + + +def get_application_logger(): + return Logger.setup_logger("application") + + +def get_initialization_logger(): + return Logger.setup_logger("initialization") + + +def get_middleware_logger(): + return Logger.setup_logger("middleware") + + +def get_routes_logger(): + return Logger.setup_logger("routes") + + +def get_config_routes_logger(): + return Logger.setup_logger("config_routes") + + +def get_config_logger(): + return Logger.setup_logger("config") + + +def get_database_logger(): + return Logger.setup_logger("database") + + +def get_log_routes_logger(): + return Logger.setup_logger("log_routes") + + +def get_stats_logger(): + return Logger.setup_logger("stats") + + +def get_update_logger(): + return Logger.setup_logger("update_service") + + +def get_scheduler_routes(): + return Logger.setup_logger("scheduler_routes") + + +def get_message_converter_logger(): + return Logger.setup_logger("message_converter") + + +def get_api_client_logger(): + return Logger.setup_logger("api_client") + + +def get_openai_compatible_logger(): + return Logger.setup_logger("openai_compatible") + + +def get_error_log_logger(): + return Logger.setup_logger("error_log") + + +def get_request_log_logger(): + return Logger.setup_logger("request_log") + + +def get_vertex_express_logger(): + return Logger.setup_logger("vertex_express") + diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf9b50fb98afa8d236e295c6b769785dbc85c6a --- /dev/null +++ b/app/main.py @@ -0,0 +1,15 @@ +import uvicorn +from dotenv import load_dotenv + +# 
在导入应用程序配置之前加载 .env 文件到环境变量 +load_dotenv() + +from app.core.application import create_app +from app.log.logger import get_main_logger + +app = create_app() + +if __name__ == "__main__": + logger = get_main_logger() + logger.info("Starting application server...") + uvicorn.run(app, host="0.0.0.0", port=8001) diff --git a/app/middleware/middleware.py b/app/middleware/middleware.py new file mode 100644 index 0000000000000000000000000000000000000000..85d512f767a3cbb0b05ab5ece6eee8c453546b20 --- /dev/null +++ b/app/middleware/middleware.py @@ -0,0 +1,80 @@ +""" +中间件配置模块,负责设置和配置应用程序的中间件 +""" + +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import RedirectResponse +from starlette.middleware.base import BaseHTTPMiddleware + +# from app.middleware.request_logging_middleware import RequestLoggingMiddleware +from app.middleware.smart_routing_middleware import SmartRoutingMiddleware +from app.core.constants import API_VERSION +from app.core.security import verify_auth_token +from app.log.logger import get_middleware_logger + +logger = get_middleware_logger() + + +class AuthMiddleware(BaseHTTPMiddleware): + """ + 认证中间件,处理未经身份验证的请求 + """ + + async def dispatch(self, request: Request, call_next): + # 允许特定路径绕过身份验证 + if ( + request.url.path not in ["/", "/auth"] + and not request.url.path.startswith("/static") + and not request.url.path.startswith("/gemini") + and not request.url.path.startswith("/v1") + and not request.url.path.startswith(f"/{API_VERSION}") + and not request.url.path.startswith("/health") + and not request.url.path.startswith("/hf") + and not request.url.path.startswith("/openai") + and not request.url.path.startswith("/api/version/check") + and not request.url.path.startswith("/vertex-express") + ): + + auth_token = request.cookies.get("auth_token") + if not auth_token or not verify_auth_token(auth_token): + logger.warning(f"Unauthorized access attempt to {request.url.path}") + return 
RedirectResponse(url="/") + logger.debug("Request authenticated successfully") + + response = await call_next(request) + return response + + +def setup_middlewares(app: FastAPI) -> None: + """ + 设置应用程序的中间件 + + Args: + app: FastAPI应用程序实例 + """ + # 添加智能路由中间件(必须在认证中间件之前) + app.add_middleware(SmartRoutingMiddleware) + + # 添加认证中间件 + app.add_middleware(AuthMiddleware) + + # 添加请求日志中间件(可选,默认注释掉) + # app.add_middleware(RequestLoggingMiddleware) + + # 配置CORS中间件 + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=[ + "GET", + "POST", + "PUT", + "DELETE", + "OPTIONS", + ], + allow_headers=["*"], + expose_headers=["*"], + max_age=600, + ) diff --git a/app/middleware/request_logging_middleware.py b/app/middleware/request_logging_middleware.py new file mode 100644 index 0000000000000000000000000000000000000000..be1ec4e18545594a867019ea43fa8ad251783bde --- /dev/null +++ b/app/middleware/request_logging_middleware.py @@ -0,0 +1,40 @@ +import json + +from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware + +from app.log.logger import get_request_logger + +logger = get_request_logger() + + +# 添加中间件类 +class RequestLoggingMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + # 记录请求路径 + logger.info(f"Request path: {request.url.path}") + + # 获取并记录请求体 + try: + body = await request.body() + if body: + body_str = body.decode() + # 尝试格式化JSON + try: + formatted_body = json.loads(body_str) + logger.info( + f"Formatted request body:\n{json.dumps(formatted_body, indent=2, ensure_ascii=False)}" + ) + except json.JSONDecodeError: + logger.error("Request body is not valid JSON.") + except Exception as e: + logger.error(f"Error reading request body: {str(e)}") + + # 重置请求的接收器,以便后续处理器可以继续读取请求体 + async def receive(): + return {"type": "http.request", "body": body, "more_body": False} + + request._receive = receive + + response = await call_next(request) + return response diff --git 
a/app/middleware/smart_routing_middleware.py b/app/middleware/smart_routing_middleware.py new file mode 100644 index 0000000000000000000000000000000000000000..4fac42f57e192c0d3a183c0cf5cf36a2c71b4d08 --- /dev/null +++ b/app/middleware/smart_routing_middleware.py @@ -0,0 +1,210 @@ +from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware +from app.config.config import settings +from app.log.logger import get_main_logger +import re + +logger = get_main_logger() + +class SmartRoutingMiddleware(BaseHTTPMiddleware): + def __init__(self, app): + super().__init__(app) + # 简化的路由规则 - 直接根据检测结果路由 + pass + + async def dispatch(self, request: Request, call_next): + if not settings.URL_NORMALIZATION_ENABLED: + return await call_next(request) + logger.debug(f"request: {request}") + original_path = str(request.url.path) + method = request.method + + # 尝试修复URL + fixed_path, fix_info = self.fix_request_url(original_path, method, request) + + if fixed_path != original_path: + logger.info(f"URL fixed: {method} {original_path} → {fixed_path}") + if fix_info: + logger.debug(f"Fix details: {fix_info}") + + # 重写请求路径 + request.scope["path"] = fixed_path + request.scope["raw_path"] = fixed_path.encode() + + return await call_next(request) + + def fix_request_url(self, path: str, method: str, request: Request) -> tuple: + """简化的URL修复逻辑""" + + # 首先检查是否已经是正确的格式,如果是则不处理 + if self.is_already_correct_format(path): + return path, None + + # 1. 最高优先级:包含generateContent → Gemini格式 + if "generatecontent" in path.lower() or "v1beta/models" in path.lower(): + return self.fix_gemini_by_operation(path, method, request) + + # 2. 第二优先级:包含/openai/ → OpenAI格式 + if "/openai/" in path.lower(): + return self.fix_openai_by_operation(path, method) + + # 3. 第三优先级:包含/v1/ → v1格式 + if "/v1/" in path.lower(): + return self.fix_v1_by_operation(path, method) + + # 4. 
fourth priority: a path containing /chat/completions → chat endpoint
        if "/chat/completions" in path.lower():
            return "/v1/chat/completions", {"type": "v1_chat"}

        # 5. Default: pass the path through unchanged
        return path, None

    def is_already_correct_format(self, path: str) -> bool:
        """Check whether the path is already in a correct API format."""
        # Known-good endpoint shapes; anything matching is left untouched
        correct_patterns = [
            r"^/v1beta/models/[^/:]+:(generate|streamGenerate)Content$",  # native Gemini
            r"^/gemini/v1beta/models/[^/:]+:(generate|streamGenerate)Content$",  # Gemini with prefix
            r"^/v1beta/models$",  # Gemini model list
            r"^/gemini/v1beta/models$",  # prefixed Gemini model list
            r"^/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$",  # v1 format
            r"^/openai/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$",  # OpenAI format
            r"^/hf/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$",  # HF format
            r"^/vertex-express/v1beta/models/[^/:]+:(generate|streamGenerate)Content$",  # Vertex Express Gemini format
            r"^/vertex-express/v1beta/models$",  # Vertex Express model list
            r"^/vertex-express/v1/(chat/completions|models|embeddings|images/generations)$",  # Vertex Express OpenAI format
        ]

        for pattern in correct_patterns:
            if re.match(pattern, path):
                return True

        return False

    def fix_gemini_by_operation(
        self, path: str, method: str, request: Request
    ) -> tuple:
        """Fix a Gemini-style path, honouring the caller's endpoint preference."""
        if method == "GET":
            return "/v1beta/models", {
                "role": "gemini_models",
            }

        # Extract the model name
        try:
            model_name = self.extract_model_name(path, request)
        except ValueError:
            # Could not extract a model name: pass the path through untouched
            return path, None

        # Detect whether this is a streaming request
        is_stream = self.detect_stream_request(path, request)

        # Check for a vertex-express preference
        if "/vertex-express/" in path.lower():
            if is_stream:
                target_url = (
                    f"/vertex-express/v1beta/models/{model_name}:streamGenerateContent"
                )
            else:
                target_url = (
                    f"/vertex-express/v1beta/models/{model_name}:generateContent"
                )

            fix_info = {
                "rule": (
                    "vertex_express_generate"
                    if not is_stream
                    else "vertex_express_stream"
                ),
                "preference": "vertex_express_format",
                "is_stream": is_stream,
                "model": model_name,
            }
        else:
            # Standard Gemini endpoint
            if is_stream:
                target_url = f"/v1beta/models/{model_name}:streamGenerateContent"
            else:
                target_url = f"/v1beta/models/{model_name}:generateContent"

            fix_info = {
                "rule": "gemini_generate" if not is_stream else "gemini_stream",
                "preference": "gemini_format",
                "is_stream": is_stream,
                "model": model_name,
            }

        return target_url, fix_info

    def fix_openai_by_operation(self, path: str, method: str) -> tuple:
        """Fix an OpenAI-prefixed path based on the operation type."""
        if method == "POST":
            if "chat" in path.lower() or "completion" in path.lower():
                return "/openai/v1/chat/completions", {"type": "openai_chat"}
            elif "embedding" in path.lower():
                return "/openai/v1/embeddings", {"type": "openai_embeddings"}
            elif "image" in path.lower():
                return "/openai/v1/images/generations", {"type": "openai_images"}
            elif "audio" in path.lower():
                return "/openai/v1/audio/speech", {"type": "openai_audio"}
        elif method == "GET":
            if "model" in path.lower():
                return "/openai/v1/models", {"type": "openai_models"}

        return path, None

    def fix_v1_by_operation(self, path: str, method: str) -> tuple:
        """Fix a /v1/ path based on the operation type."""
        if method == "POST":
            if "chat" in path.lower() or "completion" in path.lower():
                return "/v1/chat/completions", {"type": "v1_chat"}
            elif "embedding" in path.lower():
                return "/v1/embeddings", {"type": "v1_embeddings"}
            elif "image" in path.lower():
                return "/v1/images/generations", {"type": "v1_images"}
            elif "audio" in path.lower():
                return "/v1/audio/speech", {"type": "v1_audio"}
        elif method == "GET":
            if "model" in path.lower():
                return "/v1/models", {"type": "v1_models"}

        return path, None

    def detect_stream_request(self, path: str, request: Request) -> bool:
        """Detect whether the request asks for streaming output."""
        # 1. The path contains the "stream" keyword
        if "stream" in path.lower():
            return True

        # 2. 
查询参数 + if request.query_params.get("stream") == "true": + return True + + return False + + def extract_model_name(self, path: str, request: Request) -> str: + """从请求中提取模型名称,用于构建Gemini API URL""" + # 1. 从请求体中提取 + try: + if hasattr(request, "_body") and request._body: + import json + + body = json.loads(request._body.decode()) + if "model" in body and body["model"]: + return body["model"] + except Exception: + pass + + # 2. 从查询参数中提取 + model_param = request.query_params.get("model") + if model_param: + return model_param + + # 3. 从路径中提取(用于已包含模型名称的路径) + match = re.search(r"/models/([^/:]+)", path, re.IGNORECASE) + if match: + return match.group(1) + + # 4. 如果无法提取模型名称,抛出异常 + raise ValueError("Unable to extract model name from request") diff --git a/app/router/config_routes.py b/app/router/config_routes.py new file mode 100644 index 0000000000000000000000000000000000000000..6c453fff52061390122dab6d01713694c01d8717 --- /dev/null +++ b/app/router/config_routes.py @@ -0,0 +1,133 @@ +""" +配置路由模块 +""" + +from typing import Any, Dict, List + +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import RedirectResponse +from pydantic import BaseModel, Field + +from app.core.security import verify_auth_token +from app.log.logger import Logger, get_config_routes_logger +from app.service.config.config_service import ConfigService + +router = APIRouter(prefix="/api/config", tags=["config"]) + +logger = get_config_routes_logger() + + +@router.get("", response_model=Dict[str, Any]) +async def get_config(request: Request): + auth_token = request.cookies.get("auth_token") + if not auth_token or not verify_auth_token(auth_token): + logger.warning("Unauthorized access attempt to config page") + return RedirectResponse(url="/", status_code=302) + return await ConfigService.get_config() + + +@router.put("", response_model=Dict[str, Any]) +async def update_config(config_data: Dict[str, Any], request: Request): + auth_token = request.cookies.get("auth_token") + if 
not auth_token or not verify_auth_token(auth_token): + logger.warning("Unauthorized access attempt to config page") + return RedirectResponse(url="/", status_code=302) + try: + result = await ConfigService.update_config(config_data) + # 配置更新成功后,立即更新所有 logger 的级别 + Logger.update_log_levels(config_data["LOG_LEVEL"]) + logger.info("Log levels updated after configuration change.") + return result + except Exception as e: + logger.error(f"Error updating config or log levels: {e}", exc_info=True) + raise HTTPException(status_code=400, detail=str(e)) + + +@router.post("/reset", response_model=Dict[str, Any]) +async def reset_config(request: Request): + auth_token = request.cookies.get("auth_token") + if not auth_token or not verify_auth_token(auth_token): + logger.warning("Unauthorized access attempt to config page") + return RedirectResponse(url="/", status_code=302) + try: + return await ConfigService.reset_config() + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + +class DeleteKeysRequest(BaseModel): + keys: List[str] = Field(..., description="List of API keys to delete") + + +@router.delete("/keys/{key_to_delete}", response_model=Dict[str, Any]) +async def delete_single_key(key_to_delete: str, request: Request): + auth_token = request.cookies.get("auth_token") + if not auth_token or not verify_auth_token(auth_token): + logger.warning(f"Unauthorized attempt to delete key: {key_to_delete}") + return RedirectResponse(url="/", status_code=302) + try: + logger.info(f"Attempting to delete key: {key_to_delete}") + result = await ConfigService.delete_key(key_to_delete) + if not result.get("success"): + raise HTTPException( + status_code=( + 404 if "not found" in result.get("message", "").lower() else 400 + ), + detail=result.get("message"), + ) + return result + except HTTPException as e: + raise e + except Exception as e: + logger.error(f"Error deleting key '{key_to_delete}': {e}", exc_info=True) + raise HTTPException(status_code=500, 
detail=f"Error deleting key: {str(e)}") + + +@router.post("/keys/delete-selected", response_model=Dict[str, Any]) +async def delete_selected_keys_route( + delete_request: DeleteKeysRequest, request: Request +): + auth_token = request.cookies.get("auth_token") + if not auth_token or not verify_auth_token(auth_token): + logger.warning("Unauthorized attempt to bulk delete keys") + return RedirectResponse(url="/", status_code=302) + + if not delete_request.keys: + logger.warning("Attempt to bulk delete keys with an empty list.") + raise HTTPException(status_code=400, detail="No keys provided for deletion.") + + try: + logger.info(f"Attempting to bulk delete {len(delete_request.keys)} keys.") + result = await ConfigService.delete_selected_keys(delete_request.keys) + if not result.get("success") and result.get("deleted_count", 0) == 0: + raise HTTPException( + status_code=400, detail=result.get("message", "Failed to delete keys.") + ) + return result + except HTTPException as e: + raise e + except Exception as e: + logger.error(f"Error bulk deleting keys: {e}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Error bulk deleting keys: {str(e)}" + ) + + +@router.get("/ui/models") +async def get_ui_models(request: Request): + auth_token_cookie = request.cookies.get("auth_token") + if not auth_token_cookie or not verify_auth_token(auth_token_cookie): + logger.warning("Unauthorized access attempt to /api/config/ui/models") + raise HTTPException(status_code=403, detail="Not authenticated") + + try: + models = await ConfigService.fetch_ui_models() + return models + except HTTPException as e: + raise e + except Exception as e: + logger.error(f"Unexpected error in /ui/models endpoint: {e}", exc_info=True) + raise HTTPException( + status_code=500, + detail=f"An unexpected error occurred while fetching UI models: {str(e)}", + ) diff --git a/app/router/error_log_routes.py b/app/router/error_log_routes.py new file mode 100644 index 
0000000000000000000000000000000000000000..88e24c7cb5002d777b83e373b789352467e00ea2 --- /dev/null +++ b/app/router/error_log_routes.py @@ -0,0 +1,233 @@ +""" +日志路由模块 +""" + +from datetime import datetime +from typing import Dict, List, Optional + +from fastapi import ( + APIRouter, + Body, + HTTPException, + Path, + Query, + Request, + Response, + status, +) +from pydantic import BaseModel + +from app.core.security import verify_auth_token +from app.log.logger import get_log_routes_logger +from app.service.error_log import error_log_service + +router = APIRouter(prefix="/api/logs", tags=["logs"]) + +logger = get_log_routes_logger() + + +class ErrorLogListItem(BaseModel): + id: int + gemini_key: Optional[str] = None + error_type: Optional[str] = None + error_code: Optional[int] = None + model_name: Optional[str] = None + request_time: Optional[datetime] = None + + +class ErrorLogListResponse(BaseModel): + logs: List[ErrorLogListItem] + total: int + + +@router.get("/errors", response_model=ErrorLogListResponse) +async def get_error_logs_api( + request: Request, + limit: int = Query(10, ge=1, le=1000), + offset: int = Query(0, ge=0), + key_search: Optional[str] = Query( + None, description="Search term for Gemini key (partial match)" + ), + error_search: Optional[str] = Query( + None, description="Search term for error type or log message" + ), + error_code_search: Optional[str] = Query( + None, description="Search term for error code" + ), + start_date: Optional[datetime] = Query( + None, description="Start datetime for filtering" + ), + end_date: Optional[datetime] = Query( + None, description="End datetime for filtering" + ), + sort_by: str = Query( + "id", description="Field to sort by (e.g., 'id', 'request_time')" + ), + sort_order: str = Query("desc", description="Sort order ('asc' or 'desc')"), +): + """ + 获取错误日志列表 (返回错误码),支持过滤和排序 + + Args: + request: 请求对象 + limit: 限制数量 + offset: 偏移量 + key_search: 密钥搜索 + error_search: 错误搜索 (可能搜索类型或日志内容,由DB层决定) + 
error_code_search: 错误码搜索 + start_date: 开始日期 + end_date: 结束日期 + sort_by: 排序字段 + sort_order: 排序顺序 + + Returns: + ErrorLogListResponse: An object containing the list of logs (with error_code) and the total count. + """ + auth_token = request.cookies.get("auth_token") + if not auth_token or not verify_auth_token(auth_token): + logger.warning("Unauthorized access attempt to error logs list") + raise HTTPException(status_code=401, detail="Not authenticated") + + try: + result = await error_log_service.process_get_error_logs( + limit=limit, + offset=offset, + key_search=key_search, + error_search=error_search, + error_code_search=error_code_search, + start_date=start_date, + end_date=end_date, + sort_by=sort_by, + sort_order=sort_order, + ) + logs_data = result["logs"] + total_count = result["total"] + + validated_logs = [ErrorLogListItem(**log) for log in logs_data] + return ErrorLogListResponse(logs=validated_logs, total=total_count) + except Exception as e: + logger.exception(f"Failed to get error logs list: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to get error logs list: {str(e)}" + ) + + +class ErrorLogDetailResponse(BaseModel): + id: int + gemini_key: Optional[str] = None + error_type: Optional[str] = None + error_log: Optional[str] = None + request_msg: Optional[str] = None + model_name: Optional[str] = None + request_time: Optional[datetime] = None + + +@router.get("/errors/{log_id}/details", response_model=ErrorLogDetailResponse) +async def get_error_log_detail_api(request: Request, log_id: int = Path(..., ge=1)): + """ + 根据日志 ID 获取错误日志的详细信息 (包括 error_log 和 request_msg) + """ + auth_token = request.cookies.get("auth_token") + if not auth_token or not verify_auth_token(auth_token): + logger.warning( + f"Unauthorized access attempt to error log details for ID: {log_id}" + ) + raise HTTPException(status_code=401, detail="Not authenticated") + + try: + log_details = await error_log_service.process_get_error_log_details( + log_id=log_id + ) 
+ if not log_details: + raise HTTPException(status_code=404, detail="Error log not found") + + return ErrorLogDetailResponse(**log_details) + except HTTPException as http_exc: + raise http_exc + except Exception as e: + logger.exception(f"Failed to get error log details for ID {log_id}: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to get error log details: {str(e)}" + ) + + +@router.delete("/errors", status_code=status.HTTP_204_NO_CONTENT) +async def delete_error_logs_bulk_api( + request: Request, payload: Dict[str, List[int]] = Body(...) +): + """ + 批量删除错误日志 (异步) + """ + auth_token = request.cookies.get("auth_token") + if not auth_token or not verify_auth_token(auth_token): + logger.warning("Unauthorized access attempt to bulk delete error logs") + raise HTTPException(status_code=401, detail="Not authenticated") + + log_ids = payload.get("ids") + if not log_ids: + raise HTTPException(status_code=400, detail="No log IDs provided for deletion.") + + try: + deleted_count = await error_log_service.process_delete_error_logs_by_ids( + log_ids + ) + # 注意:异步函数返回的是尝试删除的数量,可能不是精确值 + logger.info( + f"Attempted bulk deletion for {deleted_count} error logs with IDs: {log_ids}" + ) + return Response(status_code=status.HTTP_204_NO_CONTENT) + except Exception as e: + logger.exception(f"Error bulk deleting error logs with IDs {log_ids}: {str(e)}") + raise HTTPException( + status_code=500, detail="Internal server error during bulk deletion" + ) + + +@router.delete("/errors/all", status_code=status.HTTP_204_NO_CONTENT) +async def delete_all_error_logs_api(request: Request): + """ + 删除所有错误日志 (异步) + """ + auth_token = request.cookies.get("auth_token") + if not auth_token or not verify_auth_token(auth_token): + logger.warning("Unauthorized access attempt to delete all error logs") + raise HTTPException(status_code=401, detail="Not authenticated") + + try: + deleted_count = await error_log_service.process_delete_all_error_logs() + logger.info(f"Successfully 
deleted all {deleted_count} error logs.") + # No body needed for 204 response + return Response(status_code=status.HTTP_204_NO_CONTENT) + except Exception as e: + logger.exception(f"Error deleting all error logs: {str(e)}") + raise HTTPException( + status_code=500, detail="Internal server error during deletion of all logs" + ) + + +@router.delete("/errors/{log_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_error_log_api(request: Request, log_id: int = Path(..., ge=1)): + """ + 删除单个错误日志 (异步) + """ + auth_token = request.cookies.get("auth_token") + if not auth_token or not verify_auth_token(auth_token): + logger.warning(f"Unauthorized access attempt to delete error log ID: {log_id}") + raise HTTPException(status_code=401, detail="Not authenticated") + + try: + success = await error_log_service.process_delete_error_log_by_id(log_id) + if not success: + # 服务层现在在未找到时返回 False,我们在这里转换为 404 + raise HTTPException( + status_code=404, detail=f"Error log with ID {log_id} not found" + ) + logger.info(f"Successfully deleted error log with ID: {log_id}") + return Response(status_code=status.HTTP_204_NO_CONTENT) + except HTTPException as http_exc: + raise http_exc + except Exception as e: + logger.exception(f"Error deleting error log with ID {log_id}: {str(e)}") + raise HTTPException( + status_code=500, detail="Internal server error during deletion" + ) diff --git a/app/router/gemini_routes.py b/app/router/gemini_routes.py new file mode 100644 index 0000000000000000000000000000000000000000..95bf88af63ad2ed5a470d7ffd5efcf11689cab37 --- /dev/null +++ b/app/router/gemini_routes.py @@ -0,0 +1,374 @@ +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import StreamingResponse, JSONResponse +from copy import deepcopy +import asyncio +from app.config.config import settings +from app.log.logger import get_gemini_logger +from app.core.security import SecurityService +from app.domain.gemini_models import GeminiContent, GeminiRequest, 
ResetSelectedKeysRequest, VerifySelectedKeysRequest +from app.service.chat.gemini_chat_service import GeminiChatService +from app.service.key.key_manager import KeyManager, get_key_manager_instance +from app.service.model.model_service import ModelService +from app.handler.retry_handler import RetryHandler +from app.handler.error_handler import handle_route_errors +from app.core.constants import API_VERSION + +router = APIRouter(prefix=f"/gemini/{API_VERSION}") +router_v1beta = APIRouter(prefix=f"/{API_VERSION}") +logger = get_gemini_logger() + +security_service = SecurityService() +model_service = ModelService() + + +async def get_key_manager(): + """获取密钥管理器实例""" + return await get_key_manager_instance() + + +async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager)): + """获取下一个可用的API密钥""" + return await key_manager.get_next_working_key() + + +async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)): + """获取Gemini聊天服务实例""" + return GeminiChatService(settings.BASE_URL, key_manager) + + +@router.get("/models") +@router_v1beta.get("/models") +async def list_models( + _=Depends(security_service.verify_key_or_goog_api_key), + key_manager: KeyManager = Depends(get_key_manager) +): + """获取可用的 Gemini 模型列表,并根据配置添加衍生模型(搜索、图像、非思考)。""" + operation_name = "list_gemini_models" + logger.info("-" * 50 + operation_name + "-" * 50) + logger.info("Handling Gemini models list request") + + try: + api_key = await key_manager.get_first_valid_key() + if not api_key: + raise HTTPException(status_code=503, detail="No valid API keys available to fetch models.") + logger.info(f"Using API key: {api_key}") + + models_data = await model_service.get_gemini_models(api_key) + if not models_data or "models" not in models_data: + raise HTTPException(status_code=500, detail="Failed to fetch base models list.") + + models_json = deepcopy(models_data) + model_mapping = {x.get("name", "").split("/", maxsplit=1)[-1]: x for x in models_json.get("models", [])} 
+ + def add_derived_model(base_name, suffix, display_suffix): + model = model_mapping.get(base_name) + if not model: + logger.warning(f"Base model '{base_name}' not found for derived model '{suffix}'.") + return + item = deepcopy(model) + item["name"] = f"models/{base_name}{suffix}" + display_name = f'{item.get("displayName", base_name)}{display_suffix}' + item["displayName"] = display_name + item["description"] = display_name + models_json["models"].append(item) + + if settings.SEARCH_MODELS: + for name in settings.SEARCH_MODELS: + add_derived_model(name, "-search", " For Search") + if settings.IMAGE_MODELS: + for name in settings.IMAGE_MODELS: + add_derived_model(name, "-image", " For Image") + if settings.THINKING_MODELS: + for name in settings.THINKING_MODELS: + add_derived_model(name, "-non-thinking", " Non Thinking") + + logger.info("Gemini models list request successful") + return models_json + except HTTPException as http_exc: + raise http_exc + except Exception as e: + logger.error(f"Error getting Gemini models list: {str(e)}") + raise HTTPException( + status_code=500, detail="Internal server error while fetching Gemini models list" + ) from e + + +@router.post("/models/{model_name}:generateContent") +@router_v1beta.post("/models/{model_name}:generateContent") +@RetryHandler(key_arg="api_key") +async def generate_content( + model_name: str, + request: GeminiRequest, + _=Depends(security_service.verify_key_or_goog_api_key), + api_key: str = Depends(get_next_working_key), + key_manager: KeyManager = Depends(get_key_manager), + chat_service: GeminiChatService = Depends(get_chat_service) +): + """处理 Gemini 非流式内容生成请求。""" + operation_name = "gemini_generate_content" + async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"): + logger.info(f"Handling Gemini content generation request for model: {model_name}") + logger.debug(f"Request: \n{request.model_dump_json(indent=2)}") + logger.info(f"Using API key: {api_key}") + + 
if not await model_service.check_model_support(model_name): + raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported") + + response = await chat_service.generate_content( + model=model_name, + request=request, + api_key=api_key + ) + return response + + +@router.post("/models/{model_name}:streamGenerateContent") +@router_v1beta.post("/models/{model_name}:streamGenerateContent") +@RetryHandler(key_arg="api_key") +async def stream_generate_content( + model_name: str, + request: GeminiRequest, + _=Depends(security_service.verify_key_or_goog_api_key), + api_key: str = Depends(get_next_working_key), + key_manager: KeyManager = Depends(get_key_manager), + chat_service: GeminiChatService = Depends(get_chat_service) +): + """处理 Gemini 流式内容生成请求。""" + operation_name = "gemini_stream_generate_content" + async with handle_route_errors(logger, operation_name, failure_message="Streaming request initiation failed"): + logger.info(f"Handling Gemini streaming content generation for model: {model_name}") + logger.debug(f"Request: \n{request.model_dump_json(indent=2)}") + logger.info(f"Using API key: {api_key}") + + if not await model_service.check_model_support(model_name): + raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported") + + response_stream = chat_service.stream_generate_content( + model=model_name, + request=request, + api_key=api_key + ) + return StreamingResponse(response_stream, media_type="text/event-stream") + + +@router.post("/reset-all-fail-counts") +async def reset_all_key_fail_counts(key_type: str = None, key_manager: KeyManager = Depends(get_key_manager)): + """批量重置Gemini API密钥的失败计数,可选择性地仅重置有效或无效密钥""" + logger.info("-" * 50 + "reset_all_gemini_key_fail_counts" + "-" * 50) + logger.info(f"Received reset request with key_type: {key_type}") + + try: + # 获取分类后的密钥 + keys_by_status = await key_manager.get_keys_by_status() + valid_keys = keys_by_status.get("valid_keys", {}) + invalid_keys = 
keys_by_status.get("invalid_keys", {}) + + # 根据类型选择要重置的密钥 + keys_to_reset = [] + if key_type == "valid": + keys_to_reset = list(valid_keys.keys()) + logger.info(f"Resetting only valid keys, count: {len(keys_to_reset)}") + elif key_type == "invalid": + keys_to_reset = list(invalid_keys.keys()) + logger.info(f"Resetting only invalid keys, count: {len(keys_to_reset)}") + else: + # 重置所有密钥 + await key_manager.reset_failure_counts() + return JSONResponse({"success": True, "message": "所有密钥的失败计数已重置"}) + + # 批量重置指定类型的密钥 + for key in keys_to_reset: + await key_manager.reset_key_failure_count(key) + + return JSONResponse({ + "success": True, + "message": f"{key_type}密钥的失败计数已重置", + "reset_count": len(keys_to_reset) + }) + except Exception as e: + logger.error(f"Failed to reset key failure counts: {str(e)}") + return JSONResponse({"success": False, "message": f"批量重置失败: {str(e)}"}, status_code=500) + + +@router.post("/reset-selected-fail-counts") +async def reset_selected_key_fail_counts( + request: ResetSelectedKeysRequest, + key_manager: KeyManager = Depends(get_key_manager) +): + """批量重置选定Gemini API密钥的失败计数""" + logger.info("-" * 50 + "reset_selected_gemini_key_fail_counts" + "-" * 50) + keys_to_reset = request.keys + key_type = request.key_type + logger.info(f"Received reset request for {len(keys_to_reset)} selected {key_type} keys.") + + if not keys_to_reset: + return JSONResponse({"success": False, "message": "没有提供需要重置的密钥"}, status_code=400) + + reset_count = 0 + errors = [] + + try: + for key in keys_to_reset: + try: + result = await key_manager.reset_key_failure_count(key) + if result: + reset_count += 1 + else: + logger.warning(f"Key not found during selective reset: {key}") + except Exception as key_error: + logger.error(f"Error resetting key {key}: {str(key_error)}") + errors.append(f"Key {key}: {str(key_error)}") + + if errors: + error_message = f"批量重置完成,但出现错误: {'; '.join(errors)}" + final_success = reset_count > 0 + status_code = 207 if final_success and errors else 
500 + return JSONResponse({ + "success": final_success, + "message": error_message, + "reset_count": reset_count + }, status_code=status_code) + + return JSONResponse({ + "success": True, + "message": f"成功重置 {reset_count} 个选定 {key_type} 密钥的失败计数", + "reset_count": reset_count + }) + except Exception as e: + logger.error(f"Failed to process reset selected key failure counts request: {str(e)}") + return JSONResponse({"success": False, "message": f"批量重置处理失败: {str(e)}"}, status_code=500) + + +@router.post("/reset-fail-count/{api_key}") +async def reset_key_fail_count(api_key: str, key_manager: KeyManager = Depends(get_key_manager)): + """重置指定Gemini API密钥的失败计数""" + logger.info("-" * 50 + "reset_gemini_key_fail_count" + "-" * 50) + logger.info(f"Resetting failure count for API key: {api_key}") + + try: + result = await key_manager.reset_key_failure_count(api_key) + if result: + return JSONResponse({"success": True, "message": "失败计数已重置"}) + return JSONResponse({"success": False, "message": "未找到指定密钥"}, status_code=404) + except Exception as e: + logger.error(f"Failed to reset key failure count: {str(e)}") + return JSONResponse({"success": False, "message": f"重置失败: {str(e)}"}, status_code=500) + + +@router.post("/verify-key/{api_key}") +async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get_chat_service), key_manager: KeyManager = Depends(get_key_manager)): + """验证Gemini API密钥的有效性""" + logger.info("-" * 50 + "verify_gemini_key" + "-" * 50) + logger.info("Verifying API key validity") + + try: + gemini_request = GeminiRequest( + contents=[ + GeminiContent( + role="user", + parts=[{"text": "hi"}], + ) + ], + generation_config={"temperature": 0.7, "top_p": 1.0, "max_output_tokens": 10} + ) + + response = await chat_service.generate_content( + settings.TEST_MODEL, + gemini_request, + api_key + ) + + if response: + return JSONResponse({"status": "valid"}) + except Exception as e: + logger.error(f"Key verification failed: {str(e)}") + + async with 
key_manager.failure_count_lock: + if api_key in key_manager.key_failure_counts: + key_manager.key_failure_counts[api_key] += 1 + logger.warning(f"Verification exception for key: {api_key}, incrementing failure count") + + return JSONResponse({"status": "invalid", "error": str(e)}) + + +@router.post("/verify-selected-keys") +async def verify_selected_keys( + request: VerifySelectedKeysRequest, + chat_service: GeminiChatService = Depends(get_chat_service), + key_manager: KeyManager = Depends(get_key_manager) +): + """批量验证选定Gemini API密钥的有效性""" + logger.info("-" * 50 + "verify_selected_gemini_keys" + "-" * 50) + keys_to_verify = request.keys + logger.info(f"Received verification request for {len(keys_to_verify)} selected keys.") + + if not keys_to_verify: + return JSONResponse({"success": False, "message": "没有提供需要验证的密钥"}, status_code=400) + + successful_keys = [] + failed_keys = {} + + async def _verify_single_key(api_key: str): + """内部函数,用于验证单个密钥并处理异常""" + nonlocal successful_keys, failed_keys + try: + gemini_request = GeminiRequest( + contents=[GeminiContent(role="user", parts=[{"text": "hi"}])], + generation_config={"temperature": 0.7, "top_p": 1.0, "max_output_tokens": 10} + ) + await chat_service.generate_content( + settings.TEST_MODEL, + gemini_request, + api_key + ) + successful_keys.append(api_key) + return api_key, "valid", None + except Exception as e: + error_message = str(e) + logger.warning(f"Key verification failed for {api_key}: {error_message}") + async with key_manager.failure_count_lock: + if api_key in key_manager.key_failure_counts: + key_manager.key_failure_counts[api_key] += 1 + logger.warning(f"Bulk verification exception for key: {api_key}, incrementing failure count") + else: + key_manager.key_failure_counts[api_key] = 1 + logger.warning(f"Bulk verification exception for key: {api_key}, initializing failure count to 1") + failed_keys[api_key] = error_message + return api_key, "invalid", error_message + + tasks = [_verify_single_key(key) for key 
in keys_to_verify] + results = await asyncio.gather(*tasks, return_exceptions=True) + + for result in results: + if isinstance(result, Exception): + logger.error(f"An unexpected error occurred during bulk verification task: {result}") + elif result: + if not isinstance(result, Exception) and result: + key, status, error = result + elif isinstance(result, Exception): + logger.error(f"Task execution error during bulk verification: {result}") + + valid_count = len(successful_keys) + invalid_count = len(failed_keys) + logger.info(f"Bulk verification finished. Valid: {valid_count}, Invalid: {invalid_count}") + + if failed_keys: + message = f"批量验证完成。成功: {valid_count}, 失败: {invalid_count}。" + return JSONResponse({ + "success": True, + "message": message, + "successful_keys": successful_keys, + "failed_keys": failed_keys, + "valid_count": valid_count, + "invalid_count": invalid_count + }) + else: + message = f"批量验证成功完成。所有 {valid_count} 个密钥均有效。" + return JSONResponse({ + "success": True, + "message": message, + "successful_keys": successful_keys, + "failed_keys": {}, + "valid_count": valid_count, + "invalid_count": 0 + }) \ No newline at end of file diff --git a/app/router/openai_compatiable_routes.py b/app/router/openai_compatiable_routes.py new file mode 100644 index 0000000000000000000000000000000000000000..68c4cbed03c51a1a0a6521b3f94105294a9e4f16 --- /dev/null +++ b/app/router/openai_compatiable_routes.py @@ -0,0 +1,113 @@ +from fastapi import APIRouter, Depends +from fastapi.responses import StreamingResponse + +from app.config.config import settings +from app.core.security import SecurityService +from app.domain.openai_models import ( + ChatRequest, + EmbeddingRequest, + ImageGenerationRequest, +) +from app.handler.retry_handler import RetryHandler +from app.handler.error_handler import handle_route_errors +from app.log.logger import get_openai_compatible_logger +from app.service.key.key_manager import KeyManager, get_key_manager_instance +from 
app.service.openai_compatiable.openai_compatiable_service import OpenAICompatiableService + + +router = APIRouter() +logger = get_openai_compatible_logger() + +security_service = SecurityService() + +async def get_key_manager(): + return await get_key_manager_instance() + + +async def get_next_working_key_wrapper( + key_manager: KeyManager = Depends(get_key_manager), +): + return await key_manager.get_next_working_key() + + +async def get_openai_service(key_manager: KeyManager = Depends(get_key_manager)): + """获取OpenAI聊天服务实例""" + return OpenAICompatiableService(settings.BASE_URL, key_manager) + + +@router.get("/openai/v1/models") +async def list_models( + _=Depends(security_service.verify_authorization), + key_manager: KeyManager = Depends(get_key_manager), + openai_service: OpenAICompatiableService = Depends(get_openai_service), +): + """获取可用模型列表。""" + operation_name = "list_models" + async with handle_route_errors(logger, operation_name): + logger.info("Handling models list request") + api_key = await key_manager.get_first_valid_key() + logger.info(f"Using API key: {api_key}") + return await openai_service.get_models(api_key) + + +@router.post("/openai/v1/chat/completions") +@RetryHandler(key_arg="api_key") +async def chat_completion( + request: ChatRequest, + _=Depends(security_service.verify_authorization), + api_key: str = Depends(get_next_working_key_wrapper), + key_manager: KeyManager = Depends(get_key_manager), + openai_service: OpenAICompatiableService = Depends(get_openai_service), +): + """处理聊天补全请求,支持流式响应和特定模型切换。""" + operation_name = "chat_completion" + is_image_chat = request.model == f"{settings.CREATE_IMAGE_MODEL}-chat" + current_api_key = api_key + if is_image_chat: + current_api_key = await key_manager.get_paid_key() + + async with handle_route_errors(logger, operation_name): + logger.info(f"Handling chat completion request for model: {request.model}") + logger.debug(f"Request: \n{request.model_dump_json(indent=2)}") + logger.info(f"Using API key: 
{current_api_key}") + + if is_image_chat: + response = await openai_service.create_image_chat_completion(request, current_api_key) + return response + else: + response = await openai_service.create_chat_completion(request, current_api_key) + if request.stream: + return StreamingResponse(response, media_type="text/event-stream") + return response + + +@router.post("/openai/v1/images/generations") +async def generate_image( + request: ImageGenerationRequest, + _=Depends(security_service.verify_authorization), + openai_service: OpenAICompatiableService = Depends(get_openai_service), +): + """处理图像生成请求。""" + operation_name = "generate_image" + async with handle_route_errors(logger, operation_name): + logger.info(f"Handling image generation request for prompt: {request.prompt}") + request.model = settings.CREATE_IMAGE_MODEL + return await openai_service.generate_images(request) + + +@router.post("/openai/v1/embeddings") +async def embedding( + request: EmbeddingRequest, + _=Depends(security_service.verify_authorization), + key_manager: KeyManager = Depends(get_key_manager), + openai_service: OpenAICompatiableService = Depends(get_openai_service), +): + """处理文本嵌入请求。""" + operation_name = "embedding" + async with handle_route_errors(logger, operation_name): + logger.info(f"Handling embedding request for model: {request.model}") + api_key = await key_manager.get_next_working_key() + logger.info(f"Using API key: {api_key}") + return await openai_service.create_embeddings( + input_text=request.input, model=request.model, api_key=api_key + ) diff --git a/app/router/openai_routes.py b/app/router/openai_routes.py new file mode 100644 index 0000000000000000000000000000000000000000..6bee7c2757c014c1564ef82822ba88951288c59a --- /dev/null +++ b/app/router/openai_routes.py @@ -0,0 +1,175 @@ +from fastapi import APIRouter, Depends, HTTPException, Response +from fastapi.responses import StreamingResponse + +from app.config.config import settings +from app.core.security import 
SecurityService +from app.domain.openai_models import ( + ChatRequest, + EmbeddingRequest, + ImageGenerationRequest, + TTSRequest, +) +from app.handler.retry_handler import RetryHandler +from app.handler.error_handler import handle_route_errors +from app.log.logger import get_openai_logger +from app.service.chat.openai_chat_service import OpenAIChatService +from app.service.embedding.embedding_service import EmbeddingService +from app.service.image.image_create_service import ImageCreateService +from app.service.tts.tts_service import TTSService +from app.service.key.key_manager import KeyManager, get_key_manager_instance +from app.service.model.model_service import ModelService + +router = APIRouter() +logger = get_openai_logger() + +security_service = SecurityService() +model_service = ModelService() +embedding_service = EmbeddingService() +image_create_service = ImageCreateService() +tts_service = TTSService() + + +async def get_key_manager(): + return await get_key_manager_instance() + + +async def get_next_working_key_wrapper( + key_manager: KeyManager = Depends(get_key_manager), +): + return await key_manager.get_next_working_key() + + +async def get_openai_chat_service(key_manager: KeyManager = Depends(get_key_manager)): + """获取OpenAI聊天服务实例""" + return OpenAIChatService(settings.BASE_URL, key_manager) + + +async def get_tts_service(): + """获取TTS服务实例""" + return tts_service + + +@router.get("/v1/models") +@router.get("/hf/v1/models") +async def list_models( + _=Depends(security_service.verify_authorization), + key_manager: KeyManager = Depends(get_key_manager), +): + """获取可用的 OpenAI 模型列表 (兼容 Gemini 和 OpenAI)。""" + operation_name = "list_models" + async with handle_route_errors(logger, operation_name): + logger.info("Handling models list request") + api_key = await key_manager.get_first_valid_key() + logger.info(f"Using API key: {api_key}") + return await model_service.get_gemini_openai_models(api_key) + + +@router.post("/v1/chat/completions") 
+@router.post("/hf/v1/chat/completions") +@RetryHandler(key_arg="api_key") +async def chat_completion( + request: ChatRequest, + _=Depends(security_service.verify_authorization), + api_key: str = Depends(get_next_working_key_wrapper), + key_manager: KeyManager = Depends(get_key_manager), + chat_service: OpenAIChatService = Depends(get_openai_chat_service), +): + """处理 OpenAI 聊天补全请求,支持流式响应和特定模型切换。""" + operation_name = "chat_completion" + is_image_chat = request.model == f"{settings.CREATE_IMAGE_MODEL}-chat" + current_api_key = api_key + if is_image_chat: + current_api_key = await key_manager.get_paid_key() + + async with handle_route_errors(logger, operation_name): + logger.info(f"Handling chat completion request for model: {request.model}") + logger.debug(f"Request: \n{request.model_dump_json(indent=2)}") + logger.info(f"Using API key: {current_api_key}") + + if not await model_service.check_model_support(request.model): + raise HTTPException( + status_code=400, detail=f"Model {request.model} is not supported" + ) + + if is_image_chat: + response = await chat_service.create_image_chat_completion(request, current_api_key) + if request.stream: + return StreamingResponse(response, media_type="text/event-stream") + return response + else: + response = await chat_service.create_chat_completion(request, current_api_key) + if request.stream: + return StreamingResponse(response, media_type="text/event-stream") + return response + + +@router.post("/v1/images/generations") +@router.post("/hf/v1/images/generations") +async def generate_image( + request: ImageGenerationRequest, + _=Depends(security_service.verify_authorization), +): + """处理 OpenAI 图像生成请求。""" + operation_name = "generate_image" + async with handle_route_errors(logger, operation_name): + logger.info(f"Handling image generation request for prompt: {request.prompt}") + response = image_create_service.generate_images(request) + return response + + +@router.post("/v1/embeddings") +@router.post("/hf/v1/embeddings") 
+async def embedding( + request: EmbeddingRequest, + _=Depends(security_service.verify_authorization), + key_manager: KeyManager = Depends(get_key_manager), +): + """处理 OpenAI 文本嵌入请求。""" + operation_name = "embedding" + async with handle_route_errors(logger, operation_name): + logger.info(f"Handling embedding request for model: {request.model}") + api_key = await key_manager.get_next_working_key() + logger.info(f"Using API key: {api_key}") + response = await embedding_service.create_embedding( + input_text=request.input, model=request.model, api_key=api_key + ) + return response + + +@router.get("/v1/keys/list") +@router.get("/hf/v1/keys/list") +async def get_keys_list( + _=Depends(security_service.verify_auth_token), + key_manager: KeyManager = Depends(get_key_manager), +): + """获取有效和无效的API key列表 (需要管理 Token 认证)。""" + operation_name = "get_keys_list" + async with handle_route_errors(logger, operation_name): + logger.info("Handling keys list request") + keys_status = await key_manager.get_keys_by_status() + return { + "status": "success", + "data": { + "valid_keys": keys_status["valid_keys"], + "invalid_keys": keys_status["invalid_keys"], + }, + "total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"]), + } + + +@router.post("/v1/audio/speech") +@router.post("/hf/v1/audio/speech") +async def text_to_speech( + request: TTSRequest, + _=Depends(security_service.verify_authorization), + api_key: str = Depends(get_next_working_key_wrapper), + tts_service: TTSService = Depends(get_tts_service), +): + """处理 OpenAI TTS 请求。""" + operation_name = "text_to_speech" + async with handle_route_errors(logger, operation_name): + logger.info(f"Handling TTS request for model: {request.model}") + logger.debug(f"Request: \n{request.model_dump_json(indent=2)}") + logger.info(f"Using API key: {api_key}") + audio_data = await tts_service.create_tts(request, api_key) + return Response(content=audio_data, media_type="audio/wav") diff --git a/app/router/routes.py 
b/app/router/routes.py new file mode 100644 index 0000000000000000000000000000000000000000..46004d78c5df69f19f4a070a25f23304bf9716a1 --- /dev/null +++ b/app/router/routes.py @@ -0,0 +1,187 @@ +""" +路由配置模块,负责设置和配置应用程序的路由 +""" + +from fastapi import FastAPI, Request +from fastapi.responses import HTMLResponse, RedirectResponse +from fastapi.templating import Jinja2Templates + +from app.core.security import verify_auth_token +from app.log.logger import get_routes_logger +from app.router import error_log_routes, gemini_routes, openai_routes, config_routes, scheduler_routes, stats_routes, version_routes, openai_compatiable_routes, vertex_express_routes +from app.service.key.key_manager import get_key_manager_instance +from app.service.stats.stats_service import StatsService + +logger = get_routes_logger() + +templates = Jinja2Templates(directory="app/templates") + + +def setup_routers(app: FastAPI) -> None: + """ + 设置应用程序的路由 + + Args: + app: FastAPI应用程序实例 + """ + app.include_router(openai_routes.router) + app.include_router(gemini_routes.router) + app.include_router(gemini_routes.router_v1beta) + app.include_router(config_routes.router) + app.include_router(error_log_routes.router) + app.include_router(scheduler_routes.router) + app.include_router(stats_routes.router) + app.include_router(version_routes.router) + app.include_router(openai_compatiable_routes.router) + app.include_router(vertex_express_routes.router) + + setup_page_routes(app) + + setup_health_routes(app) + setup_api_stats_routes(app) + + +def setup_page_routes(app: FastAPI) -> None: + """ + 设置页面相关的路由 + + Args: + app: FastAPI应用程序实例 + """ + + @app.get("/", response_class=HTMLResponse) + async def auth_page(request: Request): + """认证页面""" + return templates.TemplateResponse("auth.html", {"request": request}) + + @app.post("/auth") + async def authenticate(request: Request): + """处理认证请求""" + try: + form = await request.form() + auth_token = form.get("auth_token") + if not auth_token: + 
logger.warning("Authentication attempt with empty token") + return RedirectResponse(url="/", status_code=302) + + if verify_auth_token(auth_token): + logger.info("Successful authentication") + response = RedirectResponse(url="/config", status_code=302) + response.set_cookie( + key="auth_token", value=auth_token, httponly=True, max_age=3600 + ) + return response + logger.warning("Failed authentication attempt with invalid token") + return RedirectResponse(url="/", status_code=302) + except Exception as e: + logger.error(f"Authentication error: {str(e)}") + return RedirectResponse(url="/", status_code=302) + + @app.get("/keys", response_class=HTMLResponse) + async def keys_page(request: Request): + """密钥管理页面""" + try: + auth_token = request.cookies.get("auth_token") + if not auth_token or not verify_auth_token(auth_token): + logger.warning("Unauthorized access attempt to keys page") + return RedirectResponse(url="/", status_code=302) + + key_manager = await get_key_manager_instance() + keys_status = await key_manager.get_keys_by_status() + total_keys = len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"]) + valid_key_count = len(keys_status["valid_keys"]) + invalid_key_count = len(keys_status["invalid_keys"]) + + stats_service = StatsService() + api_stats = await stats_service.get_api_usage_stats() + logger.info(f"API stats retrieved: {api_stats}") + + logger.info(f"Keys status retrieved successfully. 
Total keys: {total_keys}") + return templates.TemplateResponse( + "keys_status.html", + { + "request": request, + "valid_keys": keys_status["valid_keys"], + "invalid_keys": keys_status["invalid_keys"], + "total_keys": total_keys, + "valid_key_count": valid_key_count, + "invalid_key_count": invalid_key_count, + "api_stats": api_stats, + }, + ) + except Exception as e: + logger.error(f"Error retrieving keys status or API stats: {str(e)}") + raise + + @app.get("/config", response_class=HTMLResponse) + async def config_page(request: Request): + """配置编辑页面""" + try: + auth_token = request.cookies.get("auth_token") + if not auth_token or not verify_auth_token(auth_token): + logger.warning("Unauthorized access attempt to config page") + return RedirectResponse(url="/", status_code=302) + + logger.info("Config page accessed successfully") + return templates.TemplateResponse("config_editor.html", {"request": request}) + except Exception as e: + logger.error(f"Error accessing config page: {str(e)}") + raise + + @app.get("/logs", response_class=HTMLResponse) + async def logs_page(request: Request): + """错误日志页面""" + try: + auth_token = request.cookies.get("auth_token") + if not auth_token or not verify_auth_token(auth_token): + logger.warning("Unauthorized access attempt to logs page") + return RedirectResponse(url="/", status_code=302) + + logger.info("Logs page accessed successfully") + return templates.TemplateResponse("error_logs.html", {"request": request}) + except Exception as e: + logger.error(f"Error accessing logs page: {str(e)}") + raise + + +def setup_health_routes(app: FastAPI) -> None: + """ + 设置健康检查相关的路由 + + Args: + app: FastAPI应用程序实例 + """ + + @app.get("/health") + async def health_check(request: Request): + """健康检查端点""" + logger.info("Health check endpoint called") + return {"status": "healthy"} + + +def setup_api_stats_routes(app: FastAPI) -> None: + """ + 设置 API 统计相关的路由 + + Args: + app: FastAPI应用程序实例 + """ + @app.get("/api/stats/details") + async def 
api_stats_details(request: Request, period: str): + """获取指定时间段内的 API 调用详情""" + try: + auth_token = request.cookies.get("auth_token") + if not auth_token or not verify_auth_token(auth_token): + logger.warning("Unauthorized access attempt to API stats details") + return {"error": "Unauthorized"}, 401 + + logger.info(f"Fetching API call details for period: {period}") + stats_service = StatsService() + details = await stats_service.get_api_call_details(period) + return details + except ValueError as e: + logger.warning(f"Invalid period requested for API stats details: {period} - {str(e)}") + return {"error": str(e)}, 400 + except Exception as e: + logger.error(f"Error fetching API stats details for period {period}: {str(e)}") + return {"error": "Internal server error"}, 500 diff --git a/app/router/scheduler_routes.py b/app/router/scheduler_routes.py new file mode 100644 index 0000000000000000000000000000000000000000..b6618f34e59b7597d179bb765402f16bcd9d6e94 --- /dev/null +++ b/app/router/scheduler_routes.py @@ -0,0 +1,57 @@ +""" +定时任务控制路由模块 +""" + +from fastapi import APIRouter, Request, HTTPException, status +from fastapi.responses import JSONResponse + +from app.core.security import verify_auth_token +from app.scheduler.scheduled_tasks import start_scheduler, stop_scheduler +from app.log.logger import get_scheduler_routes + +logger = get_scheduler_routes() + +router = APIRouter( + prefix="/api/scheduler", + tags=["Scheduler"] +) + +async def verify_token(request: Request): + auth_token = request.cookies.get("auth_token") + if not auth_token or not verify_auth_token(auth_token): + logger.warning("Unauthorized access attempt to scheduler API") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) + +@router.post("/start", summary="启动定时任务") +async def start_scheduler_endpoint(request: Request): + """Start the background scheduler task""" + await verify_token(request) + try: + 
logger.info("Received request to start scheduler.") + start_scheduler() + return JSONResponse(content={"message": "Scheduler started successfully."}, status_code=status.HTTP_200_OK) + except Exception as e: + logger.error(f"Error starting scheduler: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to start scheduler: {str(e)}" + ) + +@router.post("/stop", summary="停止定时任务") +async def stop_scheduler_endpoint(request: Request): + """Stop the background scheduler task""" + await verify_token(request) + try: + logger.info("Received request to stop scheduler.") + stop_scheduler() + return JSONResponse(content={"message": "Scheduler stopped successfully."}, status_code=status.HTTP_200_OK) + except Exception as e: + logger.error(f"Error stopping scheduler: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to stop scheduler: {str(e)}" + ) \ No newline at end of file diff --git a/app/router/stats_routes.py b/app/router/stats_routes.py new file mode 100644 index 0000000000000000000000000000000000000000..32658e9f428a0c46b6c89e44be2eb7e36a98e7a9 --- /dev/null +++ b/app/router/stats_routes.py @@ -0,0 +1,55 @@ +from fastapi import APIRouter, Depends, HTTPException, Request +from starlette import status +from app.core.security import verify_auth_token +from app.service.stats.stats_service import StatsService +from app.log.logger import get_stats_logger + +logger = get_stats_logger() + + +async def verify_token(request: Request): + auth_token = request.cookies.get("auth_token") + if not auth_token or not verify_auth_token(auth_token): + logger.warning("Unauthorized access attempt to scheduler API") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) + +router = APIRouter( + prefix="/api", + tags=["stats"], + dependencies=[Depends(verify_token)] +) + 
+stats_service = StatsService() + +@router.get("/key-usage-details/{key}", + summary="获取指定密钥最近24小时的模型调用次数", + description="根据提供的 API 密钥,返回过去24小时内每个模型被调用的次数统计。") +async def get_key_usage_details(key: str): + """ + Retrieves the model usage count for a specific API key within the last 24 hours. + + Args: + key: The API key to get usage details for. + + Returns: + A dictionary with model names as keys and their call counts as values. + Example: {"gemini-pro": 10, "gemini-1.5-pro-latest": 5} + + Raises: + HTTPException: If an error occurs during data retrieval. + """ + try: + usage_details = await stats_service.get_key_usage_details_last_24h(key) + if usage_details is None: + return {} + return usage_details + except Exception as e: + logger.error(f"Error fetching key usage details for key {key[:4]}...: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"获取密钥使用详情时出错: {e}" + ) \ No newline at end of file diff --git a/app/router/version_routes.py b/app/router/version_routes.py new file mode 100644 index 0000000000000000000000000000000000000000..28925f3a434e5e3a64b18030add62642d0791bc5 --- /dev/null +++ b/app/router/version_routes.py @@ -0,0 +1,37 @@ +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel, Field +from typing import Optional + +from app.service.update.update_service import check_for_updates +from app.utils.helpers import get_current_version +from app.log.logger import get_update_logger + +router = APIRouter(prefix="/api/version", tags=["Version"]) +logger = get_update_logger() + +class VersionInfo(BaseModel): + current_version: str = Field(..., description="当前应用程序版本") + latest_version: Optional[str] = Field(None, description="可用的最新版本") + update_available: bool = Field(False, description="是否有可用更新") + error_message: Optional[str] = Field(None, description="检查更新时发生的错误信息") + +@router.get("/check", response_model=VersionInfo, summary="检查应用程序更新") +async def get_version_info(): + """ + 检查当前应用程序版本与最新的 
GitHub release 版本。 + """ + try: + current_version = get_current_version() + update_available, latest_version, error_message = await check_for_updates() + + logger.info(f"Version check API result: current={current_version}, latest={latest_version}, available={update_available}, error='{error_message}'") + + return VersionInfo( + current_version=current_version, + latest_version=latest_version, + update_available=update_available, + error_message=error_message + ) + except Exception as e: + logger.error(f"Error in /api/version/check endpoint: {e}", exc_info=True) + raise HTTPException(status_code=500, detail="检查版本信息时发生内部错误") \ No newline at end of file diff --git a/app/router/vertex_express_routes.py b/app/router/vertex_express_routes.py new file mode 100644 index 0000000000000000000000000000000000000000..3b7f0f86cc370fcdfb97a1b4553cdef9b18bf6ae --- /dev/null +++ b/app/router/vertex_express_routes.py @@ -0,0 +1,146 @@ +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import StreamingResponse +from copy import deepcopy +from app.config.config import settings +from app.log.logger import get_vertex_express_logger +from app.core.security import SecurityService +from app.domain.gemini_models import GeminiRequest +from app.service.chat.vertex_express_chat_service import GeminiChatService +from app.service.key.key_manager import KeyManager, get_key_manager_instance +from app.service.model.model_service import ModelService +from app.handler.retry_handler import RetryHandler +from app.handler.error_handler import handle_route_errors +from app.core.constants import API_VERSION + +router = APIRouter(prefix=f"/vertex-express/{API_VERSION}") +logger = get_vertex_express_logger() + +security_service = SecurityService() +model_service = ModelService() + + +async def get_key_manager(): + """获取密钥管理器实例""" + return await get_key_manager_instance() + + +async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager)): + """获取下一个可用的API密钥""" 
+ return await key_manager.get_next_working_vertex_key() + + +async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)): + """获取Gemini聊天服务实例""" + return GeminiChatService(settings.VERTEX_EXPRESS_BASE_URL, key_manager) + + +@router.get("/models") +async def list_models( + _=Depends(security_service.verify_key_or_goog_api_key), + key_manager: KeyManager = Depends(get_key_manager) +): + """获取可用的 Gemini 模型列表,并根据配置添加衍生模型(搜索、图像、非思考)。""" + operation_name = "list_gemini_models" + logger.info("-" * 50 + operation_name + "-" * 50) + logger.info("Handling Gemini models list request") + + try: + api_key = await key_manager.get_first_valid_key() + if not api_key: + raise HTTPException(status_code=503, detail="No valid API keys available to fetch models.") + logger.info(f"Using API key: {api_key}") + + models_data = await model_service.get_gemini_models(api_key) + if not models_data or "models" not in models_data: + raise HTTPException(status_code=500, detail="Failed to fetch base models list.") + + models_json = deepcopy(models_data) + model_mapping = {x.get("name", "").split("/", maxsplit=1)[-1]: x for x in models_json.get("models", [])} + + def add_derived_model(base_name, suffix, display_suffix): + model = model_mapping.get(base_name) + if not model: + logger.warning(f"Base model '{base_name}' not found for derived model '{suffix}'.") + return + item = deepcopy(model) + item["name"] = f"models/{base_name}{suffix}" + display_name = f'{item.get("displayName", base_name)}{display_suffix}' + item["displayName"] = display_name + item["description"] = display_name + models_json["models"].append(item) + + if settings.SEARCH_MODELS: + for name in settings.SEARCH_MODELS: + add_derived_model(name, "-search", " For Search") + if settings.IMAGE_MODELS: + for name in settings.IMAGE_MODELS: + add_derived_model(name, "-image", " For Image") + if settings.THINKING_MODELS: + for name in settings.THINKING_MODELS: + add_derived_model(name, "-non-thinking", " Non Thinking") 
+ + logger.info("Gemini models list request successful") + return models_json + except HTTPException as http_exc: + raise http_exc + except Exception as e: + logger.error(f"Error getting Gemini models list: {str(e)}") + raise HTTPException( + status_code=500, detail="Internal server error while fetching Gemini models list" + ) from e + + +@router.post("/models/{model_name}:generateContent") +@RetryHandler(key_arg="api_key") +async def generate_content( + model_name: str, + request: GeminiRequest, + _=Depends(security_service.verify_key_or_goog_api_key), + api_key: str = Depends(get_next_working_key), + key_manager: KeyManager = Depends(get_key_manager), + chat_service: GeminiChatService = Depends(get_chat_service) +): + """处理 Gemini 非流式内容生成请求。""" + operation_name = "gemini_generate_content" + async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"): + logger.info(f"Handling Gemini content generation request for model: {model_name}") + logger.debug(f"Request: \n{request.model_dump_json(indent=2)}") + logger.info(f"Using API key: {api_key}") + + if not await model_service.check_model_support(model_name): + raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported") + + response = await chat_service.generate_content( + model=model_name, + request=request, + api_key=api_key + ) + return response + + +@router.post("/models/{model_name}:streamGenerateContent") +@RetryHandler(key_arg="api_key") +async def stream_generate_content( + model_name: str, + request: GeminiRequest, + _=Depends(security_service.verify_key_or_goog_api_key), + api_key: str = Depends(get_next_working_key), + key_manager: KeyManager = Depends(get_key_manager), + chat_service: GeminiChatService = Depends(get_chat_service) +): + """处理 Gemini 流式内容生成请求。""" + operation_name = "gemini_stream_generate_content" + async with handle_route_errors(logger, operation_name, failure_message="Streaming request initiation failed"): + 
logger.info(f"Handling Gemini streaming content generation for model: {model_name}") + logger.debug(f"Request: \n{request.model_dump_json(indent=2)}") + logger.info(f"Using API key: {api_key}") + + if not await model_service.check_model_support(model_name): + raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported") + + response_stream = chat_service.stream_generate_content( + model=model_name, + request=request, + api_key=api_key + ) + return StreamingResponse(response_stream, media_type="text/event-stream") \ No newline at end of file diff --git a/app/scheduler/scheduled_tasks.py b/app/scheduler/scheduled_tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..be57a23848d92bcdf061fd77ade50ca5c11e50b3 --- /dev/null +++ b/app/scheduler/scheduled_tasks.py @@ -0,0 +1,159 @@ + +from apscheduler.schedulers.asyncio import AsyncIOScheduler + +from app.config.config import settings +from app.domain.gemini_models import GeminiContent, GeminiRequest +from app.log.logger import Logger +from app.service.chat.gemini_chat_service import GeminiChatService +from app.service.error_log.error_log_service import delete_old_error_logs +from app.service.key.key_manager import get_key_manager_instance +from app.service.request_log.request_log_service import delete_old_request_logs_task + +logger = Logger.setup_logger("scheduler") + + +async def check_failed_keys(): + """ + 定时检查失败次数大于0的API密钥,并尝试验证它们。 + 如果验证成功,重置失败计数;如果失败,增加失败计数。 + """ + logger.info("Starting scheduled check for failed API keys...") + try: + key_manager = await get_key_manager_instance() + # 确保 KeyManager 已经初始化 + if not key_manager or not hasattr(key_manager, "key_failure_counts"): + logger.warning( + "KeyManager instance not available or not initialized. Skipping check." 
+ ) + return + + # 创建 GeminiChatService 实例用于验证 + # 注意:这里直接创建实例,而不是通过依赖注入,因为这是后台任务 + chat_service = GeminiChatService(settings.BASE_URL, key_manager) + + # 获取需要检查的 key 列表 (失败次数 > 0) + keys_to_check = [] + async with key_manager.failure_count_lock: # 访问共享数据需要加锁 + # 复制一份以避免在迭代时修改字典 + failure_counts_copy = key_manager.key_failure_counts.copy() + keys_to_check = [ + key for key, count in failure_counts_copy.items() if count > 0 + ] # 检查所有失败次数大于0的key + + if not keys_to_check: + logger.info("No keys with failure count > 0 found. Skipping verification.") + return + + logger.info( + f"Found {len(keys_to_check)} keys with failure count > 0 to verify." + ) + + for key in keys_to_check: + # 隐藏部分 key 用于日志记录 + log_key = f"{key[:4]}...{key[-4:]}" if len(key) > 8 else key + logger.info(f"Verifying key: {log_key}...") + try: + # 构造测试请求 + gemini_request = GeminiRequest( + contents=[ + GeminiContent( + role="user", + parts=[{"text": "hi"}], + ) + ] + ) + await chat_service.generate_content( + settings.TEST_MODEL, gemini_request, key + ) + logger.info( + f"Key {log_key} verification successful. Resetting failure count." + ) + await key_manager.reset_key_failure_count(key) + except Exception as e: + logger.warning( + f"Key {log_key} verification failed: {str(e)}. Incrementing failure count." + ) + # 直接操作计数器,需要加锁 + async with key_manager.failure_count_lock: + # 再次检查 key 是否存在且失败次数未达上限 + if ( + key in key_manager.key_failure_counts + and key_manager.key_failure_counts[key] + < key_manager.MAX_FAILURES + ): + key_manager.key_failure_counts[key] += 1 + logger.info( + f"Failure count for key {log_key} incremented to {key_manager.key_failure_counts[key]}." + ) + elif key in key_manager.key_failure_counts: + logger.warning( + f"Key {log_key} reached MAX_FAILURES ({key_manager.MAX_FAILURES}). Not incrementing further." 
+ ) + + except Exception as e: + logger.error( + f"An error occurred during the scheduled key check: {str(e)}", exc_info=True + ) + + +def setup_scheduler(): + """设置并启动 APScheduler""" + scheduler = AsyncIOScheduler(timezone=str(settings.TIMEZONE)) # 从配置读取时区 + # 添加检查失败密钥的定时任务 + scheduler.add_job( + check_failed_keys, + "interval", + hours=settings.CHECK_INTERVAL_HOURS, + id="check_failed_keys_job", + name="Check Failed API Keys", + ) + logger.info( + f"Key check job scheduled to run every {settings.CHECK_INTERVAL_HOURS} hour(s)." + ) + + # 新增:添加自动删除错误日志的定时任务,每天凌晨3点执行 + scheduler.add_job( + delete_old_error_logs, + "cron", + hour=3, + minute=0, + id="delete_old_error_logs_job", + name="Delete Old Error Logs", + ) + logger.info("Auto-delete error logs job scheduled to run daily at 3:00 AM.") + + # 新增:添加自动删除请求日志的定时任务,每天凌晨3点05分执行 + scheduler.add_job( + delete_old_request_logs_task, + "cron", + hour=3, + minute=5, + id="delete_old_request_logs_job", + name="Delete Old Request Logs", + ) + logger.info( + f"Auto-delete request logs job scheduled to run daily at 3:05 AM, if enabled and AUTO_DELETE_REQUEST_LOGS_DAYS is set to {settings.AUTO_DELETE_REQUEST_LOGS_DAYS} days." 
+ ) + + scheduler.start() + logger.info("Scheduler started with all jobs.") + return scheduler + + +# 可以在这里添加一个全局的 scheduler 实例,以便在应用关闭时优雅地停止 +scheduler_instance = None + + +def start_scheduler(): + global scheduler_instance + if scheduler_instance is None or not scheduler_instance.running: + logger.info("Starting scheduler...") + scheduler_instance = setup_scheduler() + logger.info("Scheduler is already running.") + + +def stop_scheduler(): + global scheduler_instance + if scheduler_instance and scheduler_instance.running: + scheduler_instance.shutdown() + logger.info("Scheduler stopped.") diff --git a/app/service/chat/gemini_chat_service.py b/app/service/chat/gemini_chat_service.py new file mode 100644 index 0000000000000000000000000000000000000000..cca82aea036f9f109c0ba0aa058a3b815771334c --- /dev/null +++ b/app/service/chat/gemini_chat_service.py @@ -0,0 +1,287 @@ +# app/services/chat_service.py + +import json +import re +import datetime +import time +from typing import Any, AsyncGenerator, Dict, List +from app.config.config import settings +from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS +from app.domain.gemini_models import GeminiRequest +from app.handler.response_handler import GeminiResponseHandler +from app.handler.stream_optimizer import gemini_optimizer +from app.log.logger import get_gemini_logger +from app.service.client.api_client import GeminiApiClient +from app.service.key.key_manager import KeyManager +from app.database.services import add_error_log, add_request_log + +logger = get_gemini_logger() + + +def _has_image_parts(contents: List[Dict[str, Any]]) -> bool: + """判断消息是否包含图片部分""" + for content in contents: + if "parts" in content: + for part in content["parts"]: + if "image_url" in part or "inline_data" in part: + return True + return False + + +def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]: + """构建工具""" + + def _merge_tools(tools: List[Dict[str, Any]]) -> Dict[str, Any]: + record = dict() + for 
item in tools: + if not item or not isinstance(item, dict): + continue + + for k, v in item.items(): + if k == "functionDeclarations" and v and isinstance(v, list): + functions = record.get("functionDeclarations", []) + functions.extend(v) + record["functionDeclarations"] = functions + else: + record[k] = v + return record + + tool = dict() + if payload and isinstance(payload, dict) and "tools" in payload: + if payload.get("tools") and isinstance(payload.get("tools"), dict): + payload["tools"] = [payload.get("tools")] + items = payload.get("tools", []) + if items and isinstance(items, list): + tool.update(_merge_tools(items)) + + if ( + settings.TOOLS_CODE_EXECUTION_ENABLED + and not (model.endswith("-search") or "-thinking" in model) + and not _has_image_parts(payload.get("contents", [])) + ): + tool["codeExecution"] = {} + if model.endswith("-search"): + tool["googleSearch"] = {} + + # 解决 "Tool use with function calling is unsupported" 问题 + if tool.get("functionDeclarations"): + tool.pop("googleSearch", None) + tool.pop("codeExecution", None) + + return [tool] if tool else [] + + +def _get_safety_settings(model: str) -> List[Dict[str, str]]: + """获取安全设置""" + if model == "gemini-2.0-flash-exp": + return GEMINI_2_FLASH_EXP_SAFETY_SETTINGS + return settings.SAFETY_SETTINGS + + +def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]: + """构建请求payload""" + request_dict = request.model_dump() + if request.generationConfig: + if request.generationConfig.maxOutputTokens is None: + # 如果未指定最大输出长度,则不传递该字段,解决截断的问题 + request_dict["generationConfig"].pop("maxOutputTokens") + + payload = { + "contents": request_dict.get("contents", []), + "tools": _build_tools(model, request_dict), + "safetySettings": _get_safety_settings(model), + "generationConfig": request_dict.get("generationConfig"), + "systemInstruction": request_dict.get("systemInstruction"), + } + + if model.endswith("-image") or model.endswith("-image-generation"): + payload.pop("systemInstruction") + 
payload["generationConfig"]["responseModalities"] = ["Text", "Image"] + + # 处理思考配置:优先使用客户端提供的配置,否则使用默认配置 + client_thinking_config = None + if request.generationConfig and request.generationConfig.thinkingConfig: + client_thinking_config = request.generationConfig.thinkingConfig + + if client_thinking_config is not None: + # 客户端提供了思考配置,直接使用 + payload["generationConfig"]["thinkingConfig"] = client_thinking_config + else: + # 客户端没有提供思考配置,使用默认配置 + if model.endswith("-non-thinking"): + payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0} + elif model in settings.THINKING_BUDGET_MAP: + payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000)} + + return payload + + +class GeminiChatService: + """聊天服务""" + + def __init__(self, base_url: str, key_manager: KeyManager): + self.api_client = GeminiApiClient(base_url, settings.TIME_OUT) + self.key_manager = key_manager + self.response_handler = GeminiResponseHandler() + + def _extract_text_from_response(self, response: Dict[str, Any]) -> str: + """从响应中提取文本内容""" + if not response.get("candidates"): + return "" + + candidate = response["candidates"][0] + content = candidate.get("content", {}) + parts = content.get("parts", []) + + if parts and "text" in parts[0]: + return parts[0].get("text", "") + return "" + + def _create_char_response( + self, original_response: Dict[str, Any], text: str + ) -> Dict[str, Any]: + """创建包含指定文本的响应""" + response_copy = json.loads(json.dumps(original_response)) + if response_copy.get("candidates") and response_copy["candidates"][0].get( + "content", {} + ).get("parts"): + response_copy["candidates"][0]["content"]["parts"][0]["text"] = text + return response_copy + + async def generate_content( + self, model: str, request: GeminiRequest, api_key: str + ) -> Dict[str, Any]: + """生成内容""" + payload = _build_payload(model, request) + start_time = time.perf_counter() + request_datetime = datetime.datetime.now() + is_success = False + 
status_code = None + response = None + + try: + response = await self.api_client.generate_content(payload, model, api_key) + is_success = True + status_code = 200 + return self.response_handler.handle_response(response, model, stream=False) + except Exception as e: + is_success = False + error_log_msg = str(e) + logger.error(f"Normal API call failed with error: {error_log_msg}") + match = re.search(r"status code (\d+)", error_log_msg) + if match: + status_code = int(match.group(1)) + else: + status_code = 500 + + await add_error_log( + gemini_key=api_key, + model_name=model, + error_type="gemini-chat-non-stream", + error_log=error_log_msg, + error_code=status_code, + request_msg=payload + ) + raise e + finally: + end_time = time.perf_counter() + latency_ms = int((end_time - start_time) * 1000) + await add_request_log( + model_name=model, + api_key=api_key, + is_success=is_success, + status_code=status_code, + latency_ms=latency_ms, + request_time=request_datetime + ) + + async def stream_generate_content( + self, model: str, request: GeminiRequest, api_key: str + ) -> AsyncGenerator[str, None]: + """流式生成内容""" + retries = 0 + max_retries = settings.MAX_RETRIES + payload = _build_payload(model, request) + is_success = False + status_code = None + final_api_key = api_key + + while retries < max_retries: + request_datetime = datetime.datetime.now() + start_time = time.perf_counter() + current_attempt_key = api_key + final_api_key = current_attempt_key + try: + async for line in self.api_client.stream_generate_content( + payload, model, current_attempt_key + ): + # print(line) + if line.startswith("data:"): + line = line[6:] + response_data = self.response_handler.handle_response( + json.loads(line), model, stream=True + ) + text = self._extract_text_from_response(response_data) + # 如果有文本内容,且开启了流式输出优化器,则使用流式输出优化器处理 + if text and settings.STREAM_OPTIMIZER_ENABLED: + # 使用流式输出优化器处理文本输出 + async for ( + optimized_chunk + ) in gemini_optimizer.optimize_stream_output( + text, 
+ lambda t: self._create_char_response(response_data, t), + lambda c: "data: " + json.dumps(c) + "\n\n", + ): + yield optimized_chunk + else: + # 如果没有文本内容(如工具调用等),整块输出 + yield "data: " + json.dumps(response_data) + "\n\n" + logger.info("Streaming completed successfully") + is_success = True + status_code = 200 + break + except Exception as e: + retries += 1 + is_success = False + error_log_msg = str(e) + logger.warning( + f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}" + ) + match = re.search(r"status code (\d+)", error_log_msg) + if match: + status_code = int(match.group(1)) + else: + status_code = 500 + + await add_error_log( + gemini_key=current_attempt_key, + model_name=model, + error_type="gemini-chat-stream", + error_log=error_log_msg, + error_code=status_code, + request_msg=payload + ) + + api_key = await self.key_manager.handle_api_failure(current_attempt_key, retries) + if api_key: + logger.info(f"Switched to new API key: {api_key}") + else: + logger.error(f"No valid API key available after {retries} retries.") + break + + if retries >= max_retries: + logger.error( + f"Max retries ({max_retries}) reached for streaming." 
+ ) + break + finally: + end_time = time.perf_counter() + latency_ms = int((end_time - start_time) * 1000) + await add_request_log( + model_name=model, + api_key=final_api_key, + is_success=is_success, + status_code=status_code, + latency_ms=latency_ms, + request_time=request_datetime + ) diff --git a/app/service/chat/openai_chat_service.py b/app/service/chat/openai_chat_service.py new file mode 100644 index 0000000000000000000000000000000000000000..d2866a0e3b2ba8271f427e269cd046f3a6f0a1f3 --- /dev/null +++ b/app/service/chat/openai_chat_service.py @@ -0,0 +1,606 @@ +# app/services/chat_service.py + +import asyncio +import datetime +import json +import re +import time +from copy import deepcopy +from typing import Any, AsyncGenerator, Dict, List, Optional, Union + +from app.config.config import settings +from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS +from app.database.services import ( + add_error_log, + add_request_log, +) +from app.domain.openai_models import ChatRequest, ImageGenerationRequest +from app.handler.message_converter import OpenAIMessageConverter +from app.handler.response_handler import OpenAIResponseHandler +from app.handler.stream_optimizer import openai_optimizer +from app.log.logger import get_openai_logger +from app.service.client.api_client import GeminiApiClient +from app.service.image.image_create_service import ImageCreateService +from app.service.key.key_manager import KeyManager + +logger = get_openai_logger() + + +def _has_media_parts(contents: List[Dict[str, Any]]) -> bool: + """判断消息是否包含图片、音频或视频部分 (inline_data)""" + for content in contents: + if content and "parts" in content and isinstance(content["parts"], list): + for part in content["parts"]: + if isinstance(part, dict) and "inline_data" in part: + return True + return False + + +def _build_tools( + request: ChatRequest, messages: List[Dict[str, Any]] +) -> List[Dict[str, Any]]: + """构建工具""" + tool = dict() + model = request.model + + if ( + 
settings.TOOLS_CODE_EXECUTION_ENABLED + and not ( + model.endswith("-search") + or "-thinking" in model + or model.endswith("-image") + or model.endswith("-image-generation") + ) + and not _has_media_parts(messages) + ): + tool["codeExecution"] = {} + logger.debug("Code execution tool enabled.") + elif _has_media_parts(messages): + logger.debug("Code execution tool disabled due to media parts presence.") + + if model.endswith("-search"): + tool["googleSearch"] = {} + + # 将 request 中的 tools 合并到 tools 中 + if request.tools: + function_declarations = [] + for item in request.tools: + if not item or not isinstance(item, dict): + continue + + if item.get("type", "") == "function" and item.get("function"): + function = deepcopy(item.get("function")) + parameters = function.get("parameters", {}) + if parameters.get("type") == "object" and not parameters.get( + "properties", {} + ): + function.pop("parameters", None) + + function_declarations.append(function) + + if function_declarations: + # 按照 function 的 name 去重 + names, functions = set(), [] + for fc in function_declarations: + if fc.get("name") not in names: + if fc.get("name")=="googleSearch": + # cherry开启内置搜索时,添加googleSearch工具 + tool["googleSearch"] = {} + else: + # 其他函数,添加到functionDeclarations中 + names.add(fc.get("name")) + functions.append(fc) + + tool["functionDeclarations"] = functions + + # 解决 "Tool use with function calling is unsupported" 问题 + if tool.get("functionDeclarations"): + tool.pop("googleSearch", None) + tool.pop("codeExecution", None) + + return [tool] if tool else [] + + +def _get_safety_settings(model: str) -> List[Dict[str, str]]: + """获取安全设置""" + # if ( + # "2.0" in model + # and "gemini-2.0-flash-thinking-exp" not in model + # and "gemini-2.0-pro-exp" not in model + # ): + if model == "gemini-2.0-flash-exp": + return GEMINI_2_FLASH_EXP_SAFETY_SETTINGS + return settings.SAFETY_SETTINGS + + +def _build_payload( + request: ChatRequest, + messages: List[Dict[str, Any]], + instruction: 
Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """构建请求payload""" + payload = { + "contents": messages, + "generationConfig": { + "temperature": request.temperature, + "stopSequences": request.stop, + "topP": request.top_p, + "topK": request.top_k, + }, + "tools": _build_tools(request, messages), + "safetySettings": _get_safety_settings(request.model), + } + if request.max_tokens is not None: + payload["generationConfig"]["maxOutputTokens"] = request.max_tokens + if request.model.endswith("-image") or request.model.endswith("-image-generation"): + payload["generationConfig"]["responseModalities"] = ["Text", "Image"] + if request.model.endswith("-non-thinking"): + payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0} + if request.model in settings.THINKING_BUDGET_MAP: + payload["generationConfig"]["thinkingConfig"] = { + "thinkingBudget": settings.THINKING_BUDGET_MAP.get(request.model, 1000) + } + + if ( + instruction + and isinstance(instruction, dict) + and instruction.get("role") == "system" + and instruction.get("parts") + and not request.model.endswith("-image") + and not request.model.endswith("-image-generation") + ): + payload["systemInstruction"] = instruction + + return payload + + +class OpenAIChatService: + """聊天服务""" + + def __init__(self, base_url: str, key_manager: KeyManager = None): + self.message_converter = OpenAIMessageConverter() + self.response_handler = OpenAIResponseHandler(config=None) + self.api_client = GeminiApiClient(base_url, settings.TIME_OUT) + self.key_manager = key_manager + self.image_create_service = ImageCreateService() + + def _extract_text_from_openai_chunk(self, chunk: Dict[str, Any]) -> str: + """从OpenAI响应块中提取文本内容""" + if not chunk.get("choices"): + return "" + + choice = chunk["choices"][0] + if "delta" in choice and "content" in choice["delta"]: + return choice["delta"]["content"] + return "" + + def _create_char_openai_chunk( + self, original_chunk: Dict[str, Any], text: str + ) -> Dict[str, Any]: 
+ """创建包含指定文本的OpenAI响应块""" + chunk_copy = json.loads(json.dumps(original_chunk)) + if chunk_copy.get("choices") and "delta" in chunk_copy["choices"][0]: + chunk_copy["choices"][0]["delta"]["content"] = text + return chunk_copy + + async def create_chat_completion( + self, + request: ChatRequest, + api_key: str, + ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]: + """创建聊天完成""" + messages, instruction = self.message_converter.convert(request.messages) + + payload = _build_payload(request, messages, instruction) + + if request.stream: + return self._handle_stream_completion(request.model, payload, api_key) + return await self._handle_normal_completion(request.model, payload, api_key) + + async def _handle_normal_completion( + self, model: str, payload: Dict[str, Any], api_key: str + ) -> Dict[str, Any]: + """处理普通聊天完成""" + start_time = time.perf_counter() + request_datetime = datetime.datetime.now() + is_success = False + status_code = None + response = None + try: + response = await self.api_client.generate_content(payload, model, api_key) + usage_metadata = response.get("usageMetadata", {}) + is_success = True + status_code = 200 + return self.response_handler.handle_response( + response, + model, + stream=False, + finish_reason="stop", + usage_metadata=usage_metadata, + ) + except Exception as e: + is_success = False + error_log_msg = str(e) + logger.error(f"Normal API call failed with error: {error_log_msg}") + match = re.search(r"status code (\d+)", error_log_msg) + if match: + status_code = int(match.group(1)) + else: + status_code = 500 + + await add_error_log( + gemini_key=api_key, + model_name=model, + error_type="openai-chat-non-stream", + error_log=error_log_msg, + error_code=status_code, + request_msg=payload, + ) + raise e + finally: + end_time = time.perf_counter() + latency_ms = int((end_time - start_time) * 1000) + await add_request_log( + model_name=model, + api_key=api_key, + is_success=is_success, + status_code=status_code, + 
latency_ms=latency_ms, + request_time=request_datetime, + ) + + async def _fake_stream_logic_impl( + self, model: str, payload: Dict[str, Any], api_key: str + ) -> AsyncGenerator[str, None]: + """处理伪流式 (fake stream) 的核心逻辑""" + logger.info( + f"Fake streaming enabled for model: {model}. Calling non-streaming endpoint." + ) + keep_sending_empty_data = True + + async def send_empty_data_locally() -> AsyncGenerator[str, None]: + """定期发送空数据以保持连接""" + while keep_sending_empty_data: + await asyncio.sleep(settings.FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS) + if keep_sending_empty_data: + empty_chunk = self.response_handler.handle_response({}, model, stream=True, finish_reason='stop', usage_metadata=None) + yield f"data: {json.dumps(empty_chunk)}\n\n" + logger.debug("Sent empty data chunk for fake stream heartbeat.") + + empty_data_generator = send_empty_data_locally() + api_response_task = asyncio.create_task( + self.api_client.generate_content(payload, model, api_key) + ) + + try: + while not api_response_task.done(): + try: + next_empty_chunk = await asyncio.wait_for( + empty_data_generator.__anext__(), timeout=0.1 + ) + yield next_empty_chunk + except asyncio.TimeoutError: + pass + except ( + StopAsyncIteration + ): + break + + response = await api_response_task + finally: + keep_sending_empty_data = False + + if response and response.get("candidates"): + response = self.response_handler.handle_response(response, model, stream=True, finish_reason='stop', usage_metadata=response.get("usageMetadata", {})) + yield f"data: {json.dumps(response)}\n\n" + logger.info(f"Sent full response content for fake stream: {model}") + else: + error_message = "Failed to get response from model" + if ( + response and isinstance(response, dict) and response.get("error") + ): + error_details = response.get("error") + if isinstance(error_details, dict): + error_message = error_details.get("message", error_message) + + logger.error( + f"No candidates or error in response for fake stream model 
{model}: {response}" + ) + error_chunk = self.response_handler.handle_response({}, model, stream=True, finish_reason='stop', usage_metadata=None) + yield f"data: {json.dumps(error_chunk)}\n\n" + + async def _real_stream_logic_impl( + self, model: str, payload: Dict[str, Any], api_key: str + ) -> AsyncGenerator[str, None]: + """处理真实流式 (real stream) 的核心逻辑""" + tool_call_flag = False + usage_metadata = None + async for line in self.api_client.stream_generate_content( + payload, model, api_key + ): + if line.startswith("data:"): + chunk_str = line[6:] + if not chunk_str or chunk_str.isspace(): + logger.debug( + f"Received empty data line for model {model}, skipping." + ) + continue + try: + chunk = json.loads(chunk_str) + usage_metadata = chunk.get("usageMetadata", {}) + except json.JSONDecodeError: + logger.error( + f"Failed to decode JSON from stream for model {model}: {chunk_str}" + ) + continue + openai_chunk = self.response_handler.handle_response( + chunk, model, stream=True, finish_reason=None, usage_metadata=usage_metadata + ) + if openai_chunk: + text = self._extract_text_from_openai_chunk(openai_chunk) + if text and settings.STREAM_OPTIMIZER_ENABLED: + async for ( + optimized_chunk_data + ) in openai_optimizer.optimize_stream_output( + text, + lambda t: self._create_char_openai_chunk(openai_chunk, t), + lambda c: f"data: {json.dumps(c)}\n\n", + ): + yield optimized_chunk_data + else: + if openai_chunk.get("choices") and openai_chunk["choices"][0].get("delta", {}).get("tool_calls"): + tool_call_flag = True + + yield f"data: {json.dumps(openai_chunk)}\n\n" + + if tool_call_flag: + yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='tool_calls', usage_metadata=usage_metadata))}\n\n" + else: + yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop', usage_metadata=usage_metadata))}\n\n" + + async def _handle_stream_completion( + self, model: str, payload: 
Dict[str, Any], api_key: str + ) -> AsyncGenerator[str, None]: + """处理流式聊天完成,添加重试逻辑和假流式支持""" + retries = 0 + max_retries = settings.MAX_RETRIES + is_success = False + status_code = None + final_api_key = api_key + + while retries < max_retries: + start_time = time.perf_counter() + request_datetime = datetime.datetime.now() + current_attempt_key = final_api_key + + try: + stream_generator = None + if settings.FAKE_STREAM_ENABLED: + logger.info( + f"Using fake stream logic for model: {model}, Attempt: {retries + 1}" + ) + stream_generator = self._fake_stream_logic_impl( + model, payload, current_attempt_key + ) + else: + logger.info( + f"Using real stream logic for model: {model}, Attempt: {retries + 1}" + ) + stream_generator = self._real_stream_logic_impl( + model, payload, current_attempt_key + ) + + async for chunk_data in stream_generator: + yield chunk_data + + yield "data: [DONE]\n\n" + logger.info( + f"Streaming completed successfully for model: {model}, FakeStream: {settings.FAKE_STREAM_ENABLED}, Attempt: {retries + 1}" + ) + is_success = True + status_code = 200 + break + + except Exception as e: + retries += 1 + is_success = False + error_log_msg = str(e) + logger.warning( + f"Streaming API call failed with error: {error_log_msg}. 
Attempt {retries} of {max_retries} with key {current_attempt_key}" + ) + + match = re.search(r"status code (\\d+)", error_log_msg) + if match: + status_code = int(match.group(1)) + else: + if isinstance(e, asyncio.TimeoutError): + status_code = 408 + else: + status_code = 500 + + await add_error_log( + gemini_key=current_attempt_key, + model_name=model, + error_type="openai-chat-stream", + error_log=error_log_msg, + error_code=status_code, + request_msg=payload, + ) + + if self.key_manager: + new_api_key = await self.key_manager.handle_api_failure( + current_attempt_key, retries + ) + if new_api_key and new_api_key != current_attempt_key: + final_api_key = new_api_key + logger.info( + f"Switched to new API key for next attempt: {final_api_key}" + ) + elif not new_api_key: + logger.error( + f"No valid API key available after {retries} retries, ceasing attempts for this request." + ) + break + else: + logger.error( + "KeyManager not available, cannot switch API key. Ceasing attempts for this request." + ) + break + + if retries >= max_retries: + logger.error( + f"Max retries ({max_retries}) reached for streaming model {model}." + ) + finally: + end_time = time.perf_counter() + latency_ms = int((end_time - start_time) * 1000) + await add_request_log( + model_name=model, + api_key=current_attempt_key, + is_success=is_success, + status_code=status_code, + latency_ms=latency_ms, + request_time=request_datetime, + ) + + if not is_success: + logger.error( + f"Streaming failed permanently for model {model} after {retries} attempts." 
+ ) + yield f"data: {json.dumps({'error': f'Streaming failed after {retries} retries.'})}\n\n" + yield "data: [DONE]\n\n" + + async def create_image_chat_completion( + self, request: ChatRequest, api_key: str + ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]: + + image_generate_request = ImageGenerationRequest() + image_generate_request.prompt = request.messages[-1]["content"] + image_res = self.image_create_service.generate_images_chat( + image_generate_request + ) + + if request.stream: + return self._handle_stream_image_completion( + request.model, image_res, api_key + ) + else: + return await self._handle_normal_image_completion( + request.model, image_res, api_key + ) + + async def _handle_stream_image_completion( + self, model: str, image_data: str, api_key: str + ) -> AsyncGenerator[str, None]: + logger.info(f"Starting stream image completion for model: {model}") + start_time = time.perf_counter() + request_datetime = datetime.datetime.now() + is_success = False + status_code = None + + try: + if image_data: + openai_chunk = self.response_handler.handle_image_chat_response( + image_data, model, stream=True, finish_reason=None + ) + if openai_chunk: + # 提取文本内容 + text = self._extract_text_from_openai_chunk(openai_chunk) + if text: + # 使用流式输出优化器处理文本输出 + async for ( + optimized_chunk + ) in openai_optimizer.optimize_stream_output( + text, + lambda t: self._create_char_openai_chunk(openai_chunk, t), + lambda c: f"data: {json.dumps(c)}\n\n", + ): + yield optimized_chunk + else: + # 如果没有文本内容(如图片URL等),整块输出 + yield f"data: {json.dumps(openai_chunk)}\n\n" + yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n" + logger.info( + f"Stream image completion finished successfully for model: {model}" + ) + is_success = True + status_code = 200 + yield "data: [DONE]\n\n" + except Exception as e: + is_success = False + error_log_msg = f"Stream image completion failed for model {model}: {e}" + 
logger.error(error_log_msg) + status_code = 500 + await add_error_log( + gemini_key=api_key, + model_name=model, + error_type="openai-image-stream", + error_log=error_log_msg, + error_code=status_code, + request_msg={"image_data_truncated": image_data[:1000]}, + ) + yield f"data: {json.dumps({'error': error_log_msg})}\n\n" + yield "data: [DONE]\n\n" + finally: + end_time = time.perf_counter() + latency_ms = int((end_time - start_time) * 1000) + logger.info( + f"Stream image completion for model {model} took {latency_ms} ms. Success: {is_success}" + ) + await add_request_log( + model_name=model, + api_key=api_key, + is_success=is_success, + status_code=status_code, + latency_ms=latency_ms, + request_time=request_datetime, + ) + + async def _handle_normal_image_completion( + self, model: str, image_data: str, api_key: str + ) -> Dict[str, Any]: + logger.info(f"Starting normal image completion for model: {model}") + start_time = time.perf_counter() + request_datetime = datetime.datetime.now() + is_success = False + status_code = None + result = None + + try: + result = self.response_handler.handle_image_chat_response( + image_data, model, stream=False, finish_reason="stop" + ) + logger.info( + f"Normal image completion finished successfully for model: {model}" + ) + is_success = True + status_code = 200 + return result + except Exception as e: + is_success = False + error_log_msg = f"Normal image completion failed for model {model}: {e}" + logger.error(error_log_msg) + status_code = 500 + await add_error_log( + gemini_key=api_key, + model_name=model, + error_type="openai-image-non-stream", + error_log=error_log_msg, + error_code=status_code, + request_msg={"image_data_truncated": image_data[:1000]}, + ) + raise e + finally: + end_time = time.perf_counter() + latency_ms = int((end_time - start_time) * 1000) + logger.info( + f"Normal image completion for model {model} took {latency_ms} ms. 
Success: {is_success}" + ) + await add_request_log( + model_name=model, + api_key=api_key, + is_success=is_success, + status_code=status_code, + latency_ms=latency_ms, + request_time=request_datetime, + ) diff --git a/app/service/chat/vertex_express_chat_service.py b/app/service/chat/vertex_express_chat_service.py new file mode 100644 index 0000000000000000000000000000000000000000..313cb898ffbe06be481dccc12d3cd57f242de77c --- /dev/null +++ b/app/service/chat/vertex_express_chat_service.py @@ -0,0 +1,277 @@ +# app/services/chat_service.py + +import json +import re +import datetime +import time +from typing import Any, AsyncGenerator, Dict, List +from app.config.config import settings +from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS +from app.domain.gemini_models import GeminiRequest +from app.handler.response_handler import GeminiResponseHandler +from app.handler.stream_optimizer import gemini_optimizer +from app.log.logger import get_gemini_logger +from app.service.client.api_client import GeminiApiClient +from app.service.key.key_manager import KeyManager +from app.database.services import add_error_log, add_request_log + +logger = get_gemini_logger() + + +def _has_image_parts(contents: List[Dict[str, Any]]) -> bool: + """判断消息是否包含图片部分""" + for content in contents: + if "parts" in content: + for part in content["parts"]: + if "image_url" in part or "inline_data" in part: + return True + return False + + +def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]: + """构建工具""" + + def _merge_tools(tools: List[Dict[str, Any]]) -> Dict[str, Any]: + record = dict() + for item in tools: + if not item or not isinstance(item, dict): + continue + + for k, v in item.items(): + if k == "functionDeclarations" and v and isinstance(v, list): + functions = record.get("functionDeclarations", []) + functions.extend(v) + record["functionDeclarations"] = functions + else: + record[k] = v + return record + + tool = dict() + if payload and 
isinstance(payload, dict) and "tools" in payload: + if payload.get("tools") and isinstance(payload.get("tools"), dict): + payload["tools"] = [payload.get("tools")] + items = payload.get("tools", []) + if items and isinstance(items, list): + tool.update(_merge_tools(items)) + + if ( + settings.TOOLS_CODE_EXECUTION_ENABLED + and not (model.endswith("-search") or "-thinking" in model) + and not _has_image_parts(payload.get("contents", [])) + ): + tool["codeExecution"] = {} + if model.endswith("-search"): + tool["googleSearch"] = {} + + # 解决 "Tool use with function calling is unsupported" 问题 + if tool.get("functionDeclarations"): + tool.pop("googleSearch", None) + tool.pop("codeExecution", None) + + return [tool] if tool else [] + + +def _get_safety_settings(model: str) -> List[Dict[str, str]]: + """获取安全设置""" + if model == "gemini-2.0-flash-exp": + return GEMINI_2_FLASH_EXP_SAFETY_SETTINGS + return settings.SAFETY_SETTINGS + + +def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]: + """构建请求payload""" + request_dict = request.model_dump() + if request.generationConfig: + if request.generationConfig.maxOutputTokens is None: + # 如果未指定最大输出长度,则不传递该字段,解决截断的问题 + request_dict["generationConfig"].pop("maxOutputTokens") + + payload = { + "contents": request_dict.get("contents", []), + "tools": _build_tools(model, request_dict), + "safetySettings": _get_safety_settings(model), + "generationConfig": request_dict.get("generationConfig"), + "systemInstruction": request_dict.get("systemInstruction"), + } + + if model.endswith("-image") or model.endswith("-image-generation"): + payload.pop("systemInstruction") + payload["generationConfig"]["responseModalities"] = ["Text", "Image"] + + if model.endswith("-non-thinking"): + payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0} + if model in settings.THINKING_BUDGET_MAP: + payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000)} + + return payload 
+ + +class GeminiChatService: + """聊天服务""" + + def __init__(self, base_url: str, key_manager: KeyManager): + self.api_client = GeminiApiClient(base_url, settings.TIME_OUT) + self.key_manager = key_manager + self.response_handler = GeminiResponseHandler() + + def _extract_text_from_response(self, response: Dict[str, Any]) -> str: + """从响应中提取文本内容""" + if not response.get("candidates"): + return "" + + candidate = response["candidates"][0] + content = candidate.get("content", {}) + parts = content.get("parts", []) + + if parts and "text" in parts[0]: + return parts[0].get("text", "") + return "" + + def _create_char_response( + self, original_response: Dict[str, Any], text: str + ) -> Dict[str, Any]: + """创建包含指定文本的响应""" + response_copy = json.loads(json.dumps(original_response)) # 深拷贝 + if response_copy.get("candidates") and response_copy["candidates"][0].get( + "content", {} + ).get("parts"): + response_copy["candidates"][0]["content"]["parts"][0]["text"] = text + return response_copy + + async def generate_content( + self, model: str, request: GeminiRequest, api_key: str + ) -> Dict[str, Any]: + """生成内容""" + payload = _build_payload(model, request) + start_time = time.perf_counter() + request_datetime = datetime.datetime.now() + is_success = False + status_code = None + response = None + + try: + response = await self.api_client.generate_content(payload, model, api_key) + is_success = True + status_code = 200 + return self.response_handler.handle_response(response, model, stream=False) + except Exception as e: + is_success = False + error_log_msg = str(e) + logger.error(f"Normal API call failed with error: {error_log_msg}") + match = re.search(r"status code (\d+)", error_log_msg) + if match: + status_code = int(match.group(1)) + else: + status_code = 500 + + await add_error_log( + gemini_key=api_key, + model_name=model, + error_type="gemini-chat-non-stream", + error_log=error_log_msg, + error_code=status_code, + request_msg=payload + ) + raise e + finally: + 
end_time = time.perf_counter() + latency_ms = int((end_time - start_time) * 1000) + await add_request_log( + model_name=model, + api_key=api_key, + is_success=is_success, + status_code=status_code, + latency_ms=latency_ms, + request_time=request_datetime + ) + + async def stream_generate_content( + self, model: str, request: GeminiRequest, api_key: str + ) -> AsyncGenerator[str, None]: + """流式生成内容""" + retries = 0 + max_retries = settings.MAX_RETRIES + payload = _build_payload(model, request) + is_success = False + status_code = None + final_api_key = api_key + + while retries < max_retries: + request_datetime = datetime.datetime.now() + start_time = time.perf_counter() + current_attempt_key = api_key + final_api_key = current_attempt_key # Update final key used + try: + async for line in self.api_client.stream_generate_content( + payload, model, current_attempt_key + ): + # print(line) + if line.startswith("data:"): + line = line[6:] + response_data = self.response_handler.handle_response( + json.loads(line), model, stream=True + ) + text = self._extract_text_from_response(response_data) + # 如果有文本内容,且开启了流式输出优化器,则使用流式输出优化器处理 + if text and settings.STREAM_OPTIMIZER_ENABLED: + # 使用流式输出优化器处理文本输出 + async for ( + optimized_chunk + ) in gemini_optimizer.optimize_stream_output( + text, + lambda t: self._create_char_response(response_data, t), + lambda c: "data: " + json.dumps(c) + "\n\n", + ): + yield optimized_chunk + else: + # 如果没有文本内容(如工具调用等),整块输出 + yield "data: " + json.dumps(response_data) + "\n\n" + logger.info("Streaming completed successfully") + is_success = True + status_code = 200 + break + except Exception as e: + retries += 1 + is_success = False + error_log_msg = str(e) + logger.warning( + f"Streaming API call failed with error: {error_log_msg}. 
Attempt {retries} of {max_retries}" + ) + match = re.search(r"status code (\d+)", error_log_msg) + if match: + status_code = int(match.group(1)) + else: + status_code = 500 + + await add_error_log( + gemini_key=current_attempt_key, + model_name=model, + error_type="gemini-chat-stream", + error_log=error_log_msg, + error_code=status_code, + request_msg=payload + ) + + api_key = await self.key_manager.handle_api_failure(current_attempt_key, retries) + if api_key: + logger.info(f"Switched to new API key: {api_key}") + else: + logger.error(f"No valid API key available after {retries} retries.") + break + + if retries >= max_retries: + logger.error( + f"Max retries ({max_retries}) reached for streaming." + ) + break + finally: + end_time = time.perf_counter() + latency_ms = int((end_time - start_time) * 1000) + await add_request_log( + model_name=model, + api_key=final_api_key, + is_success=is_success, + status_code=status_code, + latency_ms=latency_ms, + request_time=request_datetime + ) diff --git a/app/service/client/api_client.py b/app/service/client/api_client.py new file mode 100644 index 0000000000000000000000000000000000000000..10d439124299c17c3d5d6105f95a4cc3df698d48 --- /dev/null +++ b/app/service/client/api_client.py @@ -0,0 +1,222 @@ +# app/services/chat/api_client.py + +from typing import Dict, Any, AsyncGenerator, Optional +import httpx +import random +from abc import ABC, abstractmethod +from app.config.config import settings +from app.log.logger import get_api_client_logger +from app.core.constants import DEFAULT_TIMEOUT + +logger = get_api_client_logger() + +class ApiClient(ABC): + """API客户端基类""" + + @abstractmethod + async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]: + pass + + @abstractmethod + async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]: + pass + + +class GeminiApiClient(ApiClient): + """Gemini API客户端""" + + def 
__init__(self, base_url: str, timeout: int = DEFAULT_TIMEOUT): + self.base_url = base_url + self.timeout = timeout + + def _get_real_model(self, model: str) -> str: + if model.endswith("-search"): + model = model[:-7] + if model.endswith("-image"): + model = model[:-6] + if model.endswith("-non-thinking"): + model = model[:-13] + if "-search" in model and "-non-thinking" in model: + model = model[:-20] + return model + + async def get_models(self, api_key: str) -> Optional[Dict[str, Any]]: + """获取可用的 Gemini 模型列表""" + timeout = httpx.Timeout(timeout=5) + + proxy_to_use = None + if settings.PROXIES: + if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: + proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)] + else: + proxy_to_use = random.choice(settings.PROXIES) + logger.info(f"Using proxy for getting models: {proxy_to_use}") + + async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client: + url = f"{self.base_url}/models?key={api_key}&pageSize=1000" + try: + response = await client.get(url) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + logger.error(f"获取模型列表失败: {e.response.status_code}") + logger.error(e.response.text) + return None + except httpx.RequestError as e: + logger.error(f"请求模型列表失败: {e}") + return None + + async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]: + timeout = httpx.Timeout(self.timeout, read=self.timeout) + model = self._get_real_model(model) + + proxy_to_use = None + if settings.PROXIES: + if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: + proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)] + else: + proxy_to_use = random.choice(settings.PROXIES) + logger.info(f"Using proxy for getting models: {proxy_to_use}") + + async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client: + url = f"{self.base_url}/models/{model}:generateContent?key={api_key}" + response = await 
client.post(url, json=payload) + if response.status_code != 200: + error_content = response.text + raise Exception(f"API call failed with status code {response.status_code}, {error_content}") + return response.json() + + async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]: + timeout = httpx.Timeout(self.timeout, read=self.timeout) + model = self._get_real_model(model) + + proxy_to_use = None + if settings.PROXIES: + if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: + proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)] + else: + proxy_to_use = random.choice(settings.PROXIES) + logger.info(f"Using proxy for getting models: {proxy_to_use}") + + async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client: + url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}" + async with client.stream(method="POST", url=url, json=payload) as response: + if response.status_code != 200: + error_content = await response.aread() + error_msg = error_content.decode("utf-8") + raise Exception(f"API call failed with status code {response.status_code}, {error_msg}") + async for line in response.aiter_lines(): + yield line + + +class OpenaiApiClient(ApiClient): + """OpenAI API客户端""" + + def __init__(self, base_url: str, timeout: int = DEFAULT_TIMEOUT): + self.base_url = base_url + self.timeout = timeout + + async def get_models(self, api_key: str) -> Dict[str, Any]: + timeout = httpx.Timeout(self.timeout, read=self.timeout) + + proxy_to_use = None + if settings.PROXIES: + if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: + proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)] + else: + proxy_to_use = random.choice(settings.PROXIES) + logger.info(f"Using proxy for getting models: {proxy_to_use}") + + async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client: + url = f"{self.base_url}/openai/models" + headers = 
{"Authorization": f"Bearer {api_key}"} + response = await client.get(url, headers=headers) + if response.status_code != 200: + error_content = response.text + raise Exception(f"API call failed with status code {response.status_code}, {error_content}") + return response.json() + + async def generate_content(self, payload: Dict[str, Any], api_key: str) -> Dict[str, Any]: + timeout = httpx.Timeout(self.timeout, read=self.timeout) + logger.info(f"settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: {settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY}") + proxy_to_use = None + if settings.PROXIES: + if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: + proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)] + else: + proxy_to_use = random.choice(settings.PROXIES) + logger.info(f"Using proxy for getting models: {proxy_to_use}") + + async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client: + url = f"{self.base_url}/openai/chat/completions" + headers = {"Authorization": f"Bearer {api_key}"} + response = await client.post(url, json=payload, headers=headers) + if response.status_code != 200: + error_content = response.text + raise Exception(f"API call failed with status code {response.status_code}, {error_content}") + return response.json() + + async def stream_generate_content(self, payload: Dict[str, Any], api_key: str) -> AsyncGenerator[str, None]: + timeout = httpx.Timeout(self.timeout, read=self.timeout) + proxy_to_use = None + if settings.PROXIES: + if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: + proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)] + else: + proxy_to_use = random.choice(settings.PROXIES) + logger.info(f"Using proxy for getting models: {proxy_to_use}") + + async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client: + url = f"{self.base_url}/openai/chat/completions" + headers = {"Authorization": f"Bearer {api_key}"} + async with client.stream(method="POST", url=url, json=payload, 
headers=headers) as response: + if response.status_code != 200: + error_content = await response.aread() + error_msg = error_content.decode("utf-8") + raise Exception(f"API call failed with status code {response.status_code}, {error_msg}") + async for line in response.aiter_lines(): + yield line + + async def create_embeddings(self, input: str, model: str, api_key: str) -> Dict[str, Any]: + timeout = httpx.Timeout(self.timeout, read=self.timeout) + + proxy_to_use = None + if settings.PROXIES: + if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: + proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)] + else: + proxy_to_use = random.choice(settings.PROXIES) + logger.info(f"Using proxy for getting models: {proxy_to_use}") + + async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client: + url = f"{self.base_url}/openai/embeddings" + headers = {"Authorization": f"Bearer {api_key}"} + payload = { + "input": input, + "model": model, + } + response = await client.post(url, json=payload, headers=headers) + if response.status_code != 200: + error_content = response.text + raise Exception(f"API call failed with status code {response.status_code}, {error_content}") + return response.json() + + async def generate_images(self, payload: Dict[str, Any], api_key: str) -> Dict[str, Any]: + timeout = httpx.Timeout(self.timeout, read=self.timeout) + + proxy_to_use = None + if settings.PROXIES: + if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: + proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)] + else: + proxy_to_use = random.choice(settings.PROXIES) + logger.info(f"Using proxy for getting models: {proxy_to_use}") + + async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client: + url = f"{self.base_url}/openai/images/generations" + headers = {"Authorization": f"Bearer {api_key}"} + response = await client.post(url, json=payload, headers=headers) + if response.status_code != 200: + error_content = response.text 
+ raise Exception(f"API call failed with status code {response.status_code}, {error_content}") + return response.json() \ No newline at end of file diff --git a/app/service/config/config_service.py b/app/service/config/config_service.py new file mode 100644 index 0000000000000000000000000000000000000000..10dfa21312bb44f00f576adff12f8c81f91a7ebd --- /dev/null +++ b/app/service/config/config_service.py @@ -0,0 +1,261 @@ +""" +配置服务模块 +""" + +import datetime +import json +from typing import Any, Dict, List + +from dotenv import find_dotenv, load_dotenv +from fastapi import HTTPException +from sqlalchemy import insert, update + +from app.config.config import Settings as ConfigSettings +from app.config.config import settings +from app.database.connection import database +from app.database.models import Settings +from app.database.services import get_all_settings +from app.log.logger import get_config_routes_logger +from app.service.key.key_manager import ( + get_key_manager_instance, + reset_key_manager_instance, +) +from app.service.model.model_service import ModelService + +logger = get_config_routes_logger() + + +class ConfigService: + """配置服务类,用于管理应用程序配置""" + + @staticmethod + async def get_config() -> Dict[str, Any]: + return settings.model_dump() + + @staticmethod + async def update_config(config_data: Dict[str, Any]) -> Dict[str, Any]: + for key, value in config_data.items(): + if hasattr(settings, key): + setattr(settings, key, value) + logger.debug(f"Updated setting in memory: {key}") + + # 获取现有设置 + existing_settings_raw: List[Dict[str, Any]] = await get_all_settings() + existing_settings_map: Dict[str, Dict[str, Any]] = { + s["key"]: s for s in existing_settings_raw + } + existing_keys = set(existing_settings_map.keys()) + + settings_to_update: List[Dict[str, Any]] = [] + settings_to_insert: List[Dict[str, Any]] = [] + now = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=8))) + + # 准备要更新或插入的数据 + for key, value in config_data.items(): + # 
class ConfigService:
    """Service layer for reading, updating, resetting and persisting the
    application configuration (in-memory settings + Settings table)."""

    @staticmethod
    async def get_config() -> Dict[str, Any]:
        """Return the current in-memory configuration as a plain dict."""
        return settings.model_dump()

    @staticmethod
    async def update_config(config_data: Dict[str, Any]) -> Dict[str, Any]:
        """Apply ``config_data`` to the in-memory settings, persist changed
        values to the Settings table, and rebuild the KeyManager singleton.

        Returns:
            The full configuration after the update.
        """
        # 1. Update the in-memory settings object first so readers see the
        #    new values immediately, even if persistence fails later.
        for key, value in config_data.items():
            if hasattr(settings, key):
                setattr(settings, key, value)
                logger.debug(f"Updated setting in memory: {key}")

        # 2. Load existing rows so we can decide between INSERT and UPDATE.
        existing_settings_raw: List[Dict[str, Any]] = await get_all_settings()
        existing_settings_map: Dict[str, Dict[str, Any]] = {
            s["key"]: s for s in existing_settings_raw
        }
        existing_keys = set(existing_settings_map.keys())

        settings_to_update: List[Dict[str, Any]] = []
        settings_to_insert: List[Dict[str, Any]] = []
        # Timestamps are recorded in UTC+8, matching the rest of the project.
        now = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=8)))

        for key, value in config_data.items():
            # Serialize values into the string form used by the Settings table.
            if isinstance(value, (list, dict)):
                db_value = json.dumps(value)
            elif isinstance(value, bool):
                db_value = str(value).lower()
            else:
                db_value = str(value)

            # Skip rows whose stored value is already up to date.
            if key in existing_keys and existing_settings_map[key]["value"] == db_value:
                continue

            description = f"{key}配置项"

            data = {
                "key": key,
                "value": db_value,
                "description": description,
                "updated_at": now,
            }

            if key in existing_keys:
                # Preserve any manually edited description on existing rows.
                data["description"] = existing_settings_map[key].get(
                    "description", description
                )
                settings_to_update.append(data)
            else:
                data["created_at"] = now
                settings_to_insert.append(data)

        # 3. Persist inserts and updates atomically in one transaction.
        if settings_to_insert or settings_to_update:
            try:
                async with database.transaction():
                    if settings_to_insert:
                        query_insert = insert(Settings).values(settings_to_insert)
                        await database.execute(query=query_insert)
                        logger.info(
                            f"Bulk inserted {len(settings_to_insert)} settings."
                        )

                    for setting_data in settings_to_update:
                        query_update = (
                            update(Settings)
                            .where(Settings.key == setting_data["key"])
                            .values(
                                value=setting_data["value"],
                                description=setting_data["description"],
                                updated_at=setting_data["updated_at"],
                            )
                        )
                        await database.execute(query=query_update)
                    if settings_to_update:
                        logger.info(f"Updated {len(settings_to_update)} settings.")
            except Exception as e:
                logger.error(f"Failed to bulk update/insert settings: {str(e)}")
                raise

        # 4. Rebuild the KeyManager so key changes take effect immediately.
        try:
            await reset_key_manager_instance()
            await get_key_manager_instance(settings.API_KEYS, settings.VERTEX_API_KEYS)
            logger.info("KeyManager instance re-initialized with updated settings.")
        except Exception as e:
            logger.error(f"Failed to re-initialize KeyManager: {str(e)}")

        return await ConfigService.get_config()

    @staticmethod
    async def delete_key(key_to_delete: str) -> Dict[str, Any]:
        """Delete a single API key from the configuration.

        Returns a dict with ``success`` and a human-readable ``message``.
        """
        # Defensive: settings.API_KEYS should always be a list.
        if not isinstance(settings.API_KEYS, list):
            settings.API_KEYS = []

        original_keys_count = len(settings.API_KEYS)
        updated_api_keys = [k for k in settings.API_KEYS if k != key_to_delete]

        if len(updated_api_keys) < original_keys_count:
            # Key found and removed: update memory first, then persist via
            # update_config (which also handles the DB and the KeyManager).
            settings.API_KEYS = updated_api_keys
            await ConfigService.update_config({"API_KEYS": settings.API_KEYS})
            logger.info(f"密钥 '{key_to_delete}' 已成功删除。")
            return {"success": True, "message": f"密钥 '{key_to_delete}' 已成功删除。"}

        logger.warning(f"尝试删除密钥 '{key_to_delete}',但未找到该密钥。")
        return {"success": False, "message": f"未找到密钥 '{key_to_delete}'。"}

    @staticmethod
    async def delete_selected_keys(keys_to_delete: List[str]) -> Dict[str, Any]:
        """Bulk-delete the given API keys.

        Returns a dict with ``success``, ``message``, ``deleted_count`` and
        ``not_found_keys``.
        """
        if not isinstance(settings.API_KEYS, list):
            settings.API_KEYS = []

        deleted_count = 0
        not_found_keys: List[str] = []
        current_api_keys = list(settings.API_KEYS)
        keys_actually_removed: List[str] = []

        for key_to_del in keys_to_delete:
            if key_to_del in current_api_keys:
                current_api_keys.remove(key_to_del)
                keys_actually_removed.append(key_to_del)
                deleted_count += 1
            else:
                not_found_keys.append(key_to_del)

        if deleted_count > 0:
            settings.API_KEYS = current_api_keys
            await ConfigService.update_config({"API_KEYS": settings.API_KEYS})
            logger.info(
                f"成功删除 {deleted_count} 个密钥。密钥: {keys_actually_removed}"
            )
            message = f"成功删除 {deleted_count} 个密钥。"
            if not_found_keys:
                message += f" {len(not_found_keys)} 个密钥未找到: {not_found_keys}。"
            return {
                "success": True,
                "message": message,
                "deleted_count": deleted_count,
                "not_found_keys": not_found_keys,
            }

        message = "没有密钥被删除。"
        if not_found_keys:
            message = f"所有 {len(not_found_keys)} 个指定的密钥均未找到: {not_found_keys}。"
        elif not keys_to_delete:
            message = "未指定要删除的密钥。"
        logger.warning(message)
        return {
            "success": False,
            "message": message,
            "deleted_count": 0,
            "not_found_keys": not_found_keys,
        }

    @staticmethod
    async def reset_config() -> Dict[str, Any]:
        """Reset the configuration: reload from system environment variables
        first, then the .env file, and refresh the KeyManager.

        Returns:
            Dict[str, Any]: the configuration after the reset.
        """
        # 1. Reload the settings object (env vars take precedence over .env).
        _reload_settings()
        logger.info(
            "Settings object reloaded, prioritizing system environment variables then .env file."
        )

        # 2. Re-initialize the KeyManager with the reloaded keys.
        try:
            await reset_key_manager_instance()
            # BUG FIX: vertex keys must be passed as well -- a freshly reset
            # KeyManager cannot be created without them (get_key_manager_instance
            # raises ValueError when vertex_api_keys is None).
            await get_key_manager_instance(settings.API_KEYS, settings.VERTEX_API_KEYS)
            logger.info("KeyManager instance re-initialized with reloaded settings.")
        except Exception as e:
            # Deliberately log-and-continue: a reset that fails to rebuild the
            # KeyManager should still return the reloaded configuration.
            logger.error(f"Failed to re-initialize KeyManager during reset: {str(e)}")

        # 3. Return the updated configuration.
        return await ConfigService.get_config()

    @staticmethod
    async def fetch_ui_models() -> List[Dict[str, Any]]:
        """Fetch the model list shown in the admin UI.

        Raises:
            HTTPException: 500 when no valid key exists or the fetch fails.
        """
        try:
            key_manager = await get_key_manager_instance()
            model_service = ModelService()

            api_key = await key_manager.get_first_valid_key()
            if not api_key:
                logger.error("No valid API keys available to fetch model list for UI.")
                raise HTTPException(
                    status_code=500,
                    detail="No valid API keys available to fetch model list.",
                )

            return await model_service.get_gemini_openai_models(api_key)
        except HTTPException:
            raise
        except Exception as e:
            logger.error(
                f"Failed to fetch models for UI in ConfigService: {e}", exc_info=True
            )
            raise HTTPException(
                status_code=500, detail=f"Failed to fetch models for UI: {str(e)}"
            )


def _reload_settings():
    """Reload environment variables and refresh the shared settings object.

    Explicitly reloads the .env file (override=True), then copies every field
    of a freshly constructed Settings onto the existing singleton, so all
    modules holding a reference to ``settings`` see the new values.
    """
    load_dotenv(find_dotenv(), override=True)
    for key, value in ConfigSettings().model_dump().items():
        setattr(settings, key, value)
class EmbeddingService:
    """Creates embeddings through an OpenAI-compatible endpoint and records
    request/error logs in the database."""

    async def create_embedding(
        self, input_text: Union[str, List[str]], model: str, api_key: str
    ) -> CreateEmbeddingResponse:
        """Create embeddings using the OpenAI API, with database logging.

        Args:
            input_text: A single string or a list of strings to embed.
            model: Embedding model name.
            api_key: Upstream API key (also recorded in the logs).

        Returns:
            The CreateEmbeddingResponse from the OpenAI client.

        Raises:
            APIStatusError: propagated from the OpenAI client on HTTP errors.
            Exception: any other failure; a status code is recovered from the
                message text when possible, otherwise 500 is logged.
        """
        import asyncio  # local import keeps the module's import block untouched

        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None
        response = None
        error_log_msg = ""

        # Build a truncated snapshot of the input, used only for error logs.
        if isinstance(input_text, list):
            request_msg_log = {
                "input_truncated": [
                    str(item)[:100] + "..." if len(str(item)) > 100 else str(item)
                    for item in input_text[:5]
                ]
            }
            if len(input_text) > 5:
                request_msg_log["input_truncated"].append("...")
        else:
            request_msg_log = {
                "input_truncated": input_text[:1000] + "..."
                if len(input_text) > 1000
                else input_text
            }

        try:
            client = openai.OpenAI(api_key=api_key, base_url=settings.BASE_URL)
            # BUG FIX: the openai client here is synchronous; run the blocking
            # HTTP call in a worker thread so the event loop is not stalled.
            response = await asyncio.to_thread(
                client.embeddings.create, input=input_text, model=model
            )
            is_success = True
            status_code = 200
            return response
        except APIStatusError as e:
            is_success = False
            status_code = e.status_code
            error_log_msg = f"OpenAI API error: {e}"
            logger.error(f"Error creating embedding (APIStatusError): {error_log_msg}")
            raise e
        except Exception as e:
            is_success = False
            error_log_msg = f"Generic error: {e}"
            logger.error(f"Error creating embedding (Exception): {error_log_msg}")
            # Best effort: recover an HTTP status code from the message text.
            match = re.search(r"status code (\d+)", str(e))
            if match:
                status_code = int(match.group(1))
            else:
                status_code = 500
            raise e
        finally:
            end_time = time.perf_counter()
            latency_ms = int((end_time - start_time) * 1000)
            if not is_success:
                await add_error_log(
                    gemini_key=api_key,
                    model_name=model,
                    error_type="openai-embedding",
                    error_log=error_log_msg,
                    error_code=status_code,
                    request_msg=request_msg_log,
                )
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime,
            )
async def delete_old_error_logs():
    """Purge error logs older than the configured retention window.

    Controlled by AUTO_DELETE_ERROR_LOGS_ENABLED and
    AUTO_DELETE_ERROR_LOGS_DAYS; does nothing when disabled or misconfigured.
    """
    if not settings.AUTO_DELETE_ERROR_LOGS_ENABLED:
        logger.info("Auto-deletion of error logs is disabled. Skipping.")
        return

    days_to_keep = settings.AUTO_DELETE_ERROR_LOGS_DAYS
    if not isinstance(days_to_keep, int) or days_to_keep <= 0:
        logger.error(
            f"Invalid AUTO_DELETE_ERROR_LOGS_DAYS value: {days_to_keep}. Must be a positive integer. Skipping deletion."
        )
        return

    cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
    logger.info(
        f"Attempting to delete error logs older than {days_to_keep} days (before {cutoff_date.strftime('%Y-%m-%d %H:%M:%S %Z')})."
    )

    try:
        if not database.is_connected:
            await database.connect()
            logger.info("Database connection established for deleting error logs.")

        # Count first so the log line can report how many rows were removed.
        num_logs_to_delete = await database.fetch_val(
            select(func.count(ErrorLog.id)).where(ErrorLog.request_time < cutoff_date)
        )
        if num_logs_to_delete == 0:
            logger.info(
                "No error logs found older than the specified period. No deletion needed."
            )
            return

        logger.info(f"Found {num_logs_to_delete} error logs to delete.")
        await database.execute(
            delete(ErrorLog).where(ErrorLog.request_time < cutoff_date)
        )
        logger.info(
            f"Successfully deleted {num_logs_to_delete} error logs older than {days_to_keep} days."
        )
    except Exception as e:
        logger.error(
            f"Error during automatic deletion of error logs: {e}", exc_info=True
        )


async def process_get_error_logs(
    limit: int,
    offset: int,
    key_search: Optional[str],
    error_search: Optional[str],
    error_code_search: Optional[str],
    start_date: Optional[datetime],
    end_date: Optional[datetime],
    sort_by: str,
    sort_order: str,
) -> Dict[str, Any]:
    """Fetch one page of error logs plus the total count for the same filters."""
    shared_filters = dict(
        key_search=key_search,
        error_search=error_search,
        error_code_search=error_code_search,
        start_date=start_date,
        end_date=end_date,
    )
    try:
        logs_data = await db_services.get_error_logs(
            limit=limit,
            offset=offset,
            sort_by=sort_by,
            sort_order=sort_order,
            **shared_filters,
        )
        total_count = await db_services.get_error_logs_count(**shared_filters)
        return {"logs": logs_data, "total": total_count}
    except Exception as e:
        logger.error(f"Service error in process_get_error_logs: {e}", exc_info=True)
        raise


async def process_get_error_log_details(log_id: int) -> Optional[Dict[str, Any]]:
    """Fetch the details of one error log; returns None when not found."""
    try:
        return await db_services.get_error_log_details(log_id=log_id)
    except Exception as e:
        logger.error(
            f"Service error in process_get_error_log_details for ID {log_id}: {e}",
            exc_info=True,
        )
        raise


async def process_delete_error_logs_by_ids(log_ids: List[int]) -> int:
    """Bulk-delete error logs by ID; returns the number targeted for deletion."""
    if not log_ids:
        return 0
    try:
        return await db_services.delete_error_logs_by_ids(log_ids)
    except Exception as e:
        logger.error(
            f"Service error in process_delete_error_logs_by_ids for IDs {log_ids}: {e}",
            exc_info=True,
        )
        raise
+ return deleted_count + except Exception as e: + logger.error( + f"Service error in process_delete_error_logs_by_ids for IDs {log_ids}: {e}", + exc_info=True, + ) + raise + + +async def process_delete_error_log_by_id(log_id: int) -> bool: + """ + 按 ID 删除单个错误日志。 + 如果删除成功(或找到日志并尝试删除),则返回 True,否则返回 False。 + """ + try: + success = await db_services.delete_error_log_by_id(log_id) + return success + except Exception as e: + logger.error( + f"Service error in process_delete_error_log_by_id for ID {log_id}: {e}", + exc_info=True, + ) + raise + + +async def process_delete_all_error_logs() -> int: + """ + 处理删除所有错误日志的请求。 + 返回删除的日志数量。 + """ + try: + if not database.is_connected: + await database.connect() + logger.info("Database connection established for deleting all error logs.") + + deleted_count = await db_services.delete_all_error_logs() + logger.info( + f"Successfully processed request to delete all error logs. Count: {deleted_count}" + ) + return deleted_count + except Exception as e: + logger.error( + f"Service error in process_delete_all_error_logs: {e}", + exc_info=True, + ) + raise diff --git a/app/service/image/image_create_service.py b/app/service/image/image_create_service.py new file mode 100644 index 0000000000000000000000000000000000000000..e8ab02a105692ae80b025d2513ca445dd2985c94 --- /dev/null +++ b/app/service/image/image_create_service.py @@ -0,0 +1,162 @@ +import base64 +import time +import uuid + +from google import genai +from google.genai import types + +from app.config.config import settings +from app.core.constants import VALID_IMAGE_RATIOS +from app.domain.openai_models import ImageGenerationRequest +from app.log.logger import get_image_create_logger +from app.utils.uploader import ImageUploaderFactory + +logger = get_image_create_logger() + + +class ImageCreateService: + def __init__(self, aspect_ratio="1:1"): + self.image_model = settings.CREATE_IMAGE_MODEL + self.aspect_ratio = aspect_ratio + + def parse_prompt_parameters(self, prompt: str) -> 
tuple: + """从prompt中解析参数 + 支持的格式: + - {n:数量} 例如: {n:2} 生成2张图片 + - {ratio:比例} 例如: {ratio:16:9} 使用16:9比例 + """ + import re + + # 默认值 + n = 1 + aspect_ratio = self.aspect_ratio + + # 解析n参数 + n_match = re.search(r"{n:(\d+)}", prompt) + if n_match: + n = int(n_match.group(1)) + if n < 1 or n > 4: + raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.") + prompt = prompt.replace(n_match.group(0), "").strip() + + # 解析ratio参数 + ratio_match = re.search(r"{ratio:(\d+:\d+)}", prompt) + if ratio_match: + aspect_ratio = ratio_match.group(1) + if aspect_ratio not in VALID_IMAGE_RATIOS: + raise ValueError( + f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}" + ) + prompt = prompt.replace(ratio_match.group(0), "").strip() + + return prompt, n, aspect_ratio + + def generate_images(self, request: ImageGenerationRequest): + client = genai.Client(api_key=settings.PAID_KEY) + + if request.size == "1024x1024": + self.aspect_ratio = "1:1" + elif request.size == "1792x1024": + self.aspect_ratio = "16:9" + elif request.size == "1027x1792": + self.aspect_ratio = "9:16" + else: + raise ValueError( + f"Invalid size: {request.size}. Supported sizes are 1024x1024, 1792x1024, and 1024x1792." 
+ ) + + # 解析prompt中的参数 + cleaned_prompt, prompt_n, prompt_ratio = self.parse_prompt_parameters( + request.prompt + ) + request.prompt = cleaned_prompt + + # 如果prompt中指定了n,则覆盖请求中的n + if prompt_n > 1: + request.n = prompt_n + + # 如果prompt中指定了ratio,则覆盖默认的aspect_ratio + if prompt_ratio != self.aspect_ratio: + self.aspect_ratio = prompt_ratio + + response = client.models.generate_images( + model=self.image_model, + prompt=request.prompt, + config=types.GenerateImagesConfig( + number_of_images=request.n, + output_mime_type="image/png", + aspect_ratio=self.aspect_ratio, + safety_filter_level="BLOCK_LOW_AND_ABOVE", + person_generation="ALLOW_ADULT", + ), + ) + + if response.generated_images: + images_data = [] + for index, generated_image in enumerate(response.generated_images): + image_data = generated_image.image.image_bytes + image_uploader = None + + if request.response_format == "b64_json": + base64_image = base64.b64encode(image_data).decode("utf-8") + images_data.append( + {"b64_json": base64_image, "revised_prompt": request.prompt} + ) + else: + current_date = time.strftime("%Y/%m/%d") + filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png" + + if settings.UPLOAD_PROVIDER == "smms": + image_uploader = ImageUploaderFactory.create( + provider=settings.UPLOAD_PROVIDER, + api_key=settings.SMMS_SECRET_TOKEN, + ) + elif settings.UPLOAD_PROVIDER == "picgo": + image_uploader = ImageUploaderFactory.create( + provider=settings.UPLOAD_PROVIDER, + api_key=settings.PICGO_API_KEY, + ) + elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed": + image_uploader = ImageUploaderFactory.create( + provider=settings.UPLOAD_PROVIDER, + base_url=settings.CLOUDFLARE_IMGBED_URL, + auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE, + ) + else: + raise ValueError( + f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}" + ) + + upload_response = image_uploader.upload(image_data, filename) + + images_data.append( + { + "url": f"{upload_response.data.url}", + "revised_prompt": request.prompt, 
+ } + ) + + response_data = { + "created": int(time.time()), + "data": images_data, + } + return response_data + else: + raise Exception("I can't generate these images") + + def generate_images_chat(self, request: ImageGenerationRequest) -> str: + response = self.generate_images(request) + image_datas = response["data"] + if image_datas: + markdown_images = [] + for index, image_data in enumerate(image_datas): + if "url" in image_data: + markdown_images.append( + f"![Generated Image {index+1}]({image_data['url']})" + ) + else: + # 如果是base64格式,创建data URL + markdown_images.append( + f"![Generated Image {index+1}](data:image/png;base64,{image_data['b64_json']})" + ) + return "\n".join(markdown_images) diff --git a/app/service/key/key_manager.py b/app/service/key/key_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..94b9ae675c945e46dc41004f6bebf01315b5c960 --- /dev/null +++ b/app/service/key/key_manager.py @@ -0,0 +1,463 @@ +import asyncio +from itertools import cycle +from typing import Dict, Union + +from app.config.config import settings +from app.log.logger import get_key_manager_logger + +logger = get_key_manager_logger() + + +class KeyManager: + def __init__(self, api_keys: list, vertex_api_keys: list): + self.api_keys = api_keys + self.vertex_api_keys = vertex_api_keys + self.key_cycle = cycle(api_keys) + self.vertex_key_cycle = cycle(vertex_api_keys) + self.key_cycle_lock = asyncio.Lock() + self.vertex_key_cycle_lock = asyncio.Lock() + self.failure_count_lock = asyncio.Lock() + self.vertex_failure_count_lock = asyncio.Lock() + self.key_failure_counts: Dict[str, int] = {key: 0 for key in api_keys} + self.vertex_key_failure_counts: Dict[str, int] = { + key: 0 for key in vertex_api_keys + } + self.MAX_FAILURES = settings.MAX_FAILURES + self.paid_key = settings.PAID_KEY + + async def get_paid_key(self) -> str: + return self.paid_key + + async def get_next_key(self) -> str: + """获取下一个API key""" + async with self.key_cycle_lock: + return 
class KeyManager:
    """Round-robin pool of Gemini and Vertex API keys with per-key failure
    tracking; keys exceeding MAX_FAILURES are treated as invalid."""

    def __init__(self, api_keys: list, vertex_api_keys: list):
        self.api_keys = api_keys
        self.vertex_api_keys = vertex_api_keys
        self.key_cycle = cycle(api_keys)
        self.vertex_key_cycle = cycle(vertex_api_keys)
        # Separate locks so Gemini and Vertex rotation never contend.
        self.key_cycle_lock = asyncio.Lock()
        self.vertex_key_cycle_lock = asyncio.Lock()
        self.failure_count_lock = asyncio.Lock()
        self.vertex_failure_count_lock = asyncio.Lock()
        self.key_failure_counts: Dict[str, int] = {key: 0 for key in api_keys}
        self.vertex_key_failure_counts: Dict[str, int] = {
            key: 0 for key in vertex_api_keys
        }
        self.MAX_FAILURES = settings.MAX_FAILURES
        self.paid_key = settings.PAID_KEY

    async def get_paid_key(self) -> str:
        """Return the configured paid key."""
        return self.paid_key

    async def get_next_key(self) -> str:
        """Return the next API key in round-robin order."""
        async with self.key_cycle_lock:
            return next(self.key_cycle)

    async def get_next_vertex_key(self) -> str:
        """Return the next Vertex API key in round-robin order."""
        async with self.vertex_key_cycle_lock:
            return next(self.vertex_key_cycle)

    async def is_key_valid(self, key: str) -> bool:
        """Whether this key's failure count is still below the limit."""
        async with self.failure_count_lock:
            return self.key_failure_counts[key] < self.MAX_FAILURES

    async def is_vertex_key_valid(self, key: str) -> bool:
        """Whether this Vertex key's failure count is still below the limit."""
        async with self.vertex_failure_count_lock:
            return self.vertex_key_failure_counts[key] < self.MAX_FAILURES

    async def reset_failure_counts(self):
        """Reset the failure count of every API key to zero."""
        async with self.failure_count_lock:
            for key in self.key_failure_counts:
                self.key_failure_counts[key] = 0

    async def reset_vertex_failure_counts(self):
        """Reset the failure count of every Vertex API key to zero."""
        async with self.vertex_failure_count_lock:
            for key in self.vertex_key_failure_counts:
                self.vertex_key_failure_counts[key] = 0

    async def reset_key_failure_count(self, key: str) -> bool:
        """Reset one key's failure count; returns False for an unknown key."""
        async with self.failure_count_lock:
            if key in self.key_failure_counts:
                self.key_failure_counts[key] = 0
                logger.info(f"Reset failure count for key: {key}")
                return True
            logger.warning(
                f"Attempt to reset failure count for non-existent key: {key}"
            )
            return False

    async def reset_vertex_key_failure_count(self, key: str) -> bool:
        """Reset one Vertex key's failure count; False for an unknown key."""
        async with self.vertex_failure_count_lock:
            if key in self.vertex_key_failure_counts:
                self.vertex_key_failure_counts[key] = 0
                logger.info(f"Reset failure count for Vertex key: {key}")
                return True
            logger.warning(
                f"Attempt to reset failure count for non-existent Vertex key: {key}"
            )
            return False

    async def get_next_working_key(self) -> str:
        """Return the next valid API key.

        Cycles through the pool at most once; if every key is invalid the
        starting key is returned anyway so callers always get *some* key.
        """
        initial_key = await self.get_next_key()
        current_key = initial_key

        while True:
            if await self.is_key_valid(current_key):
                return current_key

            current_key = await self.get_next_key()
            if current_key == initial_key:
                return current_key

    async def get_next_working_vertex_key(self) -> str:
        """Return the next valid Vertex API key (same fallback semantics as
        get_next_working_key)."""
        initial_key = await self.get_next_vertex_key()
        current_key = initial_key

        while True:
            if await self.is_vertex_key_valid(current_key):
                return current_key

            current_key = await self.get_next_vertex_key()
            if current_key == initial_key:
                return current_key

    async def handle_api_failure(self, api_key: str, retries: int) -> str:
        """Record a failure for ``api_key`` and pick a replacement.

        Returns the next working key while retries remain, otherwise "".
        """
        async with self.failure_count_lock:
            self.key_failure_counts[api_key] += 1
            if self.key_failure_counts[api_key] >= self.MAX_FAILURES:
                logger.warning(
                    f"API key {api_key} has failed {self.MAX_FAILURES} times"
                )
        if retries < settings.MAX_RETRIES:
            return await self.get_next_working_key()
        else:
            return ""

    async def handle_vertex_api_failure(self, api_key: str, retries: int) -> str:
        """Record a failure for a Vertex key (no replacement key is returned)."""
        async with self.vertex_failure_count_lock:
            self.vertex_key_failure_counts[api_key] += 1
            if self.vertex_key_failure_counts[api_key] >= self.MAX_FAILURES:
                logger.warning(
                    f"Vertex API key {api_key} has failed {self.MAX_FAILURES} times"
                )

    def get_fail_count(self, key: str) -> int:
        """Failure count for the given key (0 for unknown keys)."""
        return self.key_failure_counts.get(key, 0)

    def get_vertex_fail_count(self, key: str) -> int:
        """Failure count for the given Vertex key (0 for unknown keys)."""
        return self.vertex_key_failure_counts.get(key, 0)

    async def get_keys_by_status(self) -> dict:
        """Split API keys into valid/invalid maps of key -> failure count."""
        valid_keys = {}
        invalid_keys = {}

        async with self.failure_count_lock:
            for key in self.api_keys:
                fail_count = self.key_failure_counts[key]
                if fail_count < self.MAX_FAILURES:
                    valid_keys[key] = fail_count
                else:
                    invalid_keys[key] = fail_count

        return {"valid_keys": valid_keys, "invalid_keys": invalid_keys}

    async def get_vertex_keys_by_status(self) -> dict:
        """Split Vertex keys into valid/invalid maps of key -> failure count."""
        valid_keys = {}
        invalid_keys = {}

        async with self.vertex_failure_count_lock:
            for key in self.vertex_api_keys:
                fail_count = self.vertex_key_failure_counts[key]
                if fail_count < self.MAX_FAILURES:
                    valid_keys[key] = fail_count
                else:
                    invalid_keys[key] = fail_count
        return {"valid_keys": valid_keys, "invalid_keys": invalid_keys}

    async def get_first_valid_key(self) -> str:
        """Return the first key whose failure count is under the limit.

        Falls back to the first configured key when every key is marked
        invalid, and to "" when no keys are configured at all.
        (Removed an unreachable duplicated return at the end of the
        original implementation.)
        """
        async with self.failure_count_lock:
            for key, failures in self.key_failure_counts.items():
                if failures < self.MAX_FAILURES:
                    return key
        if self.api_keys:
            return self.api_keys[0]
        logger.warning("API key list is empty, cannot get first valid key.")
        return ""
self.vertex_api_keys: + fail_count = self.vertex_key_failure_counts[key] + if fail_count < self.MAX_FAILURES: + valid_keys[key] = fail_count + else: + invalid_keys[key] = fail_count + return {"valid_keys": valid_keys, "invalid_keys": invalid_keys} + + async def get_first_valid_key(self) -> str: + """获取第一个有效的API key""" + async with self.failure_count_lock: + for key in self.key_failure_counts: + if self.key_failure_counts[key] < self.MAX_FAILURES: + return key + if self.api_keys: + return self.api_keys[0] + if not self.api_keys: + logger.warning( + "API key list is empty, cannot get first valid key.") + return "" + return self.api_keys[0] + + +_singleton_instance = None +_singleton_lock = asyncio.Lock() +_preserved_failure_counts: Union[Dict[str, int], None] = None +_preserved_vertex_failure_counts: Union[Dict[str, int], None] = None +_preserved_old_api_keys_for_reset: Union[list, None] = None +_preserved_vertex_old_api_keys_for_reset: Union[list, None] = None +_preserved_next_key_in_cycle: Union[str, None] = None +_preserved_vertex_next_key_in_cycle: Union[str, None] = None + + +async def get_key_manager_instance( + api_keys: list = None, vertex_api_keys: list = None +) -> KeyManager: + """ + 获取 KeyManager 单例实例。 + + 如果尚未创建实例,将使用提供的 api_keys,vertex_api_keys 初始化 KeyManager。 + 如果已创建实例,则忽略 api_keys 参数,返回现有单例。 + 如果在重置后调用,会尝试恢复之前的状态(失败计数、循环位置)。 + """ + global _singleton_instance, _preserved_failure_counts, _preserved_vertex_failure_counts, _preserved_old_api_keys_for_reset, _preserved_vertex_old_api_keys_for_reset, _preserved_next_key_in_cycle, _preserved_vertex_next_key_in_cycle + + async with _singleton_lock: + if _singleton_instance is None: + if api_keys is None: + raise ValueError( + "API keys are required to initialize or re-initialize the KeyManager instance." + ) + if vertex_api_keys is None: + raise ValueError( + "Vertex API keys are required to initialize or re-initialize the KeyManager instance." 
+ ) + + if not api_keys: + logger.warning( + "Initializing KeyManager with an empty list of API keys." + ) + if not vertex_api_keys: + logger.warning( + "Initializing KeyManager with an empty list of Vertex API keys." + ) + + _singleton_instance = KeyManager(api_keys, vertex_api_keys) + logger.info( + f"KeyManager instance created/re-created with {len(api_keys)} API keys and {len(vertex_api_keys)} Vertex API keys." + ) + + # 1. 恢复失败计数 + if _preserved_failure_counts: + current_failure_counts = { + key: 0 for key in _singleton_instance.api_keys + } + for key, count in _preserved_failure_counts.items(): + if key in current_failure_counts: + current_failure_counts[key] = count + _singleton_instance.key_failure_counts = current_failure_counts + logger.info("Inherited failure counts for applicable keys.") + _preserved_failure_counts = None + + if _preserved_vertex_failure_counts: + current_vertex_failure_counts = { + key: 0 for key in _singleton_instance.vertex_api_keys + } + for key, count in _preserved_vertex_failure_counts.items(): + if key in current_vertex_failure_counts: + current_vertex_failure_counts[key] = count + _singleton_instance.vertex_key_failure_counts = ( + current_vertex_failure_counts + ) + logger.info( + "Inherited failure counts for applicable Vertex keys.") + _preserved_vertex_failure_counts = None + + # 2. 
调整 key_cycle 的起始点 + start_key_for_new_cycle = None + if ( + _preserved_old_api_keys_for_reset + and _preserved_next_key_in_cycle + and _singleton_instance.api_keys + ): + try: + start_idx_in_old = _preserved_old_api_keys_for_reset.index( + _preserved_next_key_in_cycle + ) + + for i in range(len(_preserved_old_api_keys_for_reset)): + current_old_key_idx = (start_idx_in_old + i) % len( + _preserved_old_api_keys_for_reset + ) + key_candidate = _preserved_old_api_keys_for_reset[ + current_old_key_idx + ] + if key_candidate in _singleton_instance.api_keys: + start_key_for_new_cycle = key_candidate + break + except ValueError: + logger.warning( + f"Preserved next key '{_preserved_next_key_in_cycle}' not found in preserved old API keys. " + "New cycle will start from the beginning of the new list." + ) + except Exception as e: + logger.error( + f"Error determining start key for new cycle from preserved state: {e}. " + "New cycle will start from the beginning." + ) + + if start_key_for_new_cycle and _singleton_instance.api_keys: + try: + target_idx = _singleton_instance.api_keys.index( + start_key_for_new_cycle + ) + for _ in range(target_idx): + next(_singleton_instance.key_cycle) + logger.info( + f"Key cycle in new instance advanced. Next call to get_next_key() will yield: {start_key_for_new_cycle}" + ) + except ValueError: + logger.warning( + f"Determined start key '{start_key_for_new_cycle}' not found in new API keys during cycle advancement. " + "New cycle will start from the beginning." + ) + except StopIteration: + logger.error( + "StopIteration while advancing key cycle, implies empty new API key list previously missed." + ) + except Exception as e: + logger.error( + f"Error advancing new key cycle: {e}. Cycle will start from beginning." + ) + else: + if _singleton_instance.api_keys: + logger.info( + "New key cycle will start from the beginning of the new API key list (no specific start key determined or needed)." 
+ ) + else: + logger.info( + "New key cycle not applicable as the new API key list is empty." + ) + + # 清理所有保存的状态 + _preserved_old_api_keys_for_reset = None + _preserved_next_key_in_cycle = None + + # 3. 调整 vertex_key_cycle 的起始点 + start_key_for_new_vertex_cycle = None + if ( + _preserved_vertex_old_api_keys_for_reset + and _preserved_vertex_next_key_in_cycle + and _singleton_instance.vertex_api_keys + ): + try: + start_idx_in_old = _preserved_vertex_old_api_keys_for_reset.index( + _preserved_vertex_next_key_in_cycle + ) + + for i in range(len(_preserved_vertex_old_api_keys_for_reset)): + current_old_key_idx = (start_idx_in_old + i) % len( + _preserved_vertex_old_api_keys_for_reset + ) + key_candidate = _preserved_vertex_old_api_keys_for_reset[ + current_old_key_idx + ] + if key_candidate in _singleton_instance.vertex_api_keys: + start_key_for_new_vertex_cycle = key_candidate + break + except ValueError: + logger.warning( + f"Preserved next key '{_preserved_vertex_next_key_in_cycle}' not found in preserved old Vertex API keys. " + "New cycle will start from the beginning of the new list." + ) + except Exception as e: + logger.error( + f"Error determining start key for new Vertex key cycle from preserved state: {e}. " + "New cycle will start from the beginning." + ) + + if start_key_for_new_vertex_cycle and _singleton_instance.vertex_api_keys: + try: + target_idx = _singleton_instance.vertex_api_keys.index( + start_key_for_new_vertex_cycle + ) + for _ in range(target_idx): + next(_singleton_instance.vertex_key_cycle) + logger.info( + f"Vertex key cycle in new instance advanced. Next call to get_next_vertex_key() will yield: {start_key_for_new_vertex_cycle}" + ) + except ValueError: + logger.warning( + f"Determined start key '{start_key_for_new_vertex_cycle}' not found in new Vertex API keys during cycle advancement. " + "New cycle will start from the beginning." 
async def reset_key_manager_instance():
    """
    Reset the KeyManager singleton instance.

    Before dropping the instance, preserve its runtime state — failure
    counts, the old API key lists, and a hint for the next key in each
    rotation cycle — so the next call to get_key_manager_instance() can
    restore rotation position and failure bookkeeping instead of
    starting from scratch.
    """
    global _singleton_instance, _preserved_failure_counts, _preserved_vertex_failure_counts, _preserved_old_api_keys_for_reset, _preserved_vertex_old_api_keys_for_reset, _preserved_next_key_in_cycle, _preserved_vertex_next_key_in_cycle
    async with _singleton_lock:
        if not _singleton_instance:
            logger.info(
                "KeyManager instance was not set (or already reset), no reset action performed."
            )
            return

        # 1. Preserve failure counts.
        _preserved_failure_counts = _singleton_instance.key_failure_counts.copy()
        _preserved_vertex_failure_counts = _singleton_instance.vertex_key_failure_counts.copy()

        # 2. Preserve the old API key lists.
        _preserved_old_api_keys_for_reset = _singleton_instance.api_keys.copy()
        _preserved_vertex_old_api_keys_for_reset = _singleton_instance.vertex_api_keys.copy()

        # 3. Preserve the next-key hint for the regular key cycle.
        #    Best-effort: a failed hint only means the new cycle starts
        #    from the beginning of the new key list.
        try:
            if _singleton_instance.api_keys:
                _preserved_next_key_in_cycle = (
                    await _singleton_instance.get_next_key()
                )
            else:
                _preserved_next_key_in_cycle = None
        except StopIteration:
            logger.warning(
                "Could not preserve next key hint: key cycle was empty or exhausted in old instance."
            )
            _preserved_next_key_in_cycle = None
        except Exception as e:
            logger.error(
                f"Error preserving next key hint during reset: {e}")
            _preserved_next_key_in_cycle = None

        # 4. Preserve the next-key hint for the Vertex key cycle.
        try:
            if _singleton_instance.vertex_api_keys:
                _preserved_vertex_next_key_in_cycle = (
                    await _singleton_instance.get_next_vertex_key()
                )
            else:
                _preserved_vertex_next_key_in_cycle = None
        except StopIteration:
            logger.warning(
                "Could not preserve next key hint: Vertex key cycle was empty or exhausted in old instance."
            )
            _preserved_vertex_next_key_in_cycle = None
        except Exception as e:
            # Fix: the original reused the non-Vertex message here, making
            # the two preservation failures indistinguishable in logs.
            logger.error(
                f"Error preserving next Vertex key hint during reset: {e}")
            _preserved_vertex_next_key_in_cycle = None

        _singleton_instance = None
        logger.info(
            "KeyManager instance has been reset. State (failure counts, old keys, next key hint) preserved for next instantiation."
        )
None, + } + openai_format["data"].append(openai_model) + + if model_id in settings.SEARCH_MODELS: + search_model = openai_model.copy() + search_model["id"] = f"{model_id}-search" + openai_format["data"].append(search_model) + if model_id in settings.IMAGE_MODELS: + image_model = openai_model.copy() + image_model["id"] = f"{model_id}-image" + openai_format["data"].append(image_model) + if model_id in settings.THINKING_MODELS: + non_thinking_model = openai_model.copy() + non_thinking_model["id"] = f"{model_id}-non-thinking" + openai_format["data"].append(non_thinking_model) + + if settings.CREATE_IMAGE_MODEL: + image_model = openai_model.copy() + image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat" + openai_format["data"].append(image_model) + return openai_format + + async def check_model_support(self, model: str) -> bool: + if not model or not isinstance(model, str): + return False + + model = model.strip() + if model.endswith("-search"): + model = model[:-7] + return model in settings.SEARCH_MODELS + if model.endswith("-image"): + model = model[:-6] + return model in settings.IMAGE_MODELS + + return model not in settings.FILTERED_MODELS diff --git a/app/service/openai_compatiable/openai_compatiable_service.py b/app/service/openai_compatiable/openai_compatiable_service.py new file mode 100644 index 0000000000000000000000000000000000000000..51e062b8cfac9868ad46a9ef384eded6face4228 --- /dev/null +++ b/app/service/openai_compatiable/openai_compatiable_service.py @@ -0,0 +1,190 @@ + +import datetime +import json +import re +import time +from typing import Any, AsyncGenerator, Dict, Union + +from app.config.config import settings +from app.database.services import ( + add_error_log, + add_request_log, +) +from app.domain.openai_models import ChatRequest, ImageGenerationRequest +from app.service.client.api_client import OpenaiApiClient +from app.service.key.key_manager import KeyManager +from app.log.logger import get_openai_compatible_logger + +logger = 
class OpenAICompatiableService:
    """Proxy service for OpenAI-compatible upstream endpoints.

    Wraps an ``OpenaiApiClient`` for the given base URL and adds
    request/error logging plus (for streaming) retry-with-key-rotation
    through the optional ``KeyManager``.
    """

    def __init__(self, base_url: str, key_manager: KeyManager = None):
        # key_manager may be None; streaming retry/rotation is then disabled.
        self.key_manager = key_manager
        self.base_url = base_url
        self.api_client = OpenaiApiClient(base_url, settings.TIME_OUT)

    async def get_models(self, api_key: str) -> Dict[str, Any]:
        """List the models exposed by the upstream endpoint."""
        return await self.api_client.get_models(api_key)

    async def create_chat_completion(
        self,
        request: ChatRequest,
        api_key: str,
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Create a chat completion (streaming or non-streaming)."""
        request_dict = request.model_dump()
        # Strip null-valued fields so they are not forwarded upstream.
        request_dict = {k: v for k, v in request_dict.items() if v is not None}
        # Bug fix: top_k is unsupported upstream, but when it is None it has
        # already been dropped by the null-filter above, so the original
        # `del request_dict["top_k"]` raised KeyError.  pop() with a default
        # removes it only when present.
        request_dict.pop("top_k", None)
        if request.stream:
            return self._handle_stream_completion(request.model, request_dict, api_key)
        return await self._handle_normal_completion(request.model, request_dict, api_key)

    async def generate_images(
        self,
        request: ImageGenerationRequest,
    ) -> Dict[str, Any]:
        """Generate images; always billed against the configured PAID_KEY."""
        request_dict = request.model_dump()
        # Strip null-valued fields so they are not forwarded upstream.
        request_dict = {k: v for k, v in request_dict.items() if v is not None}
        api_key = settings.PAID_KEY
        return await self.api_client.generate_images(request_dict, api_key)

    async def create_embeddings(
        self,
        input_text: str,
        model: str,
        api_key: str,
    ) -> Dict[str, Any]:
        """Create embeddings for ``input_text`` with the given model."""
        return await self.api_client.create_embeddings(input_text, model, api_key)

    async def _handle_normal_completion(
        self, model: str, request: dict, api_key: str
    ) -> Dict[str, Any]:
        """Non-streaming completion with error and request logging."""
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None
        response = None
        try:
            response = await self.api_client.generate_content(request, api_key)
            is_success = True
            status_code = 200
            return response
        except Exception as e:
            is_success = False
            error_log_msg = str(e)
            logger.error(f"Normal API call failed with error: {error_log_msg}")
            # Upstream client errors embed the HTTP status in the message text.
            match = re.search(r"status code (\d+)", error_log_msg)
            if match:
                status_code = int(match.group(1))
            else:
                status_code = 500

            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type="openai-compatiable-non-stream",
                error_log=error_log_msg,
                error_code=status_code,
                request_msg=request,
            )
            raise e
        finally:
            end_time = time.perf_counter()
            latency_ms = int((end_time - start_time) * 1000)
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime,
            )

    async def _handle_stream_completion(
        self, model: str, payload: dict, api_key: str
    ) -> AsyncGenerator[str, None]:
        """Streaming completion with retry logic and key rotation."""
        retries = 0
        max_retries = settings.MAX_RETRIES
        is_success = False
        status_code = None
        final_api_key = api_key

        while retries < max_retries:
            start_time = time.perf_counter()
            request_datetime = datetime.datetime.now()
            current_attempt_key = api_key
            final_api_key = current_attempt_key
            try:
                async for line in self.api_client.stream_generate_content(
                    payload, current_attempt_key
                ):
                    if line.startswith("data:"):
                        yield line + "\n\n"
                logger.info("Streaming completed successfully")
                is_success = True
                status_code = 200
                break
            except Exception as e:
                retries += 1
                is_success = False
                error_log_msg = str(e)
                logger.warning(
                    f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
                )
                match = re.search(r"status code (\d+)", error_log_msg)
                if match:
                    status_code = int(match.group(1))
                else:
                    status_code = 500

                await add_error_log(
                    gemini_key=current_attempt_key,
                    model_name=model,
                    error_type="openai-compatiable-stream",
                    error_log=error_log_msg,
                    error_code=status_code,
                    request_msg=payload,
                )

                if self.key_manager:
                    # Rotate to a fresh key for the next attempt.
                    api_key = await self.key_manager.handle_api_failure(
                        current_attempt_key, retries
                    )
                    if api_key:
                        logger.info(f"Switched to new API key: {api_key}")
                    else:
                        logger.error(
                            f"No valid API key available after {retries} retries."
                        )
                        break
                else:
                    logger.error("KeyManager not available for retry logic.")
                    break

                if retries >= max_retries:
                    logger.error(f"Max retries ({max_retries}) reached for streaming.")
                    break
            finally:
                # One request-log row per attempt, success or failure.
                end_time = time.perf_counter()
                latency_ms = int((end_time - start_time) * 1000)
                await add_request_log(
                    model_name=model,
                    api_key=final_api_key,
                    is_success=is_success,
                    status_code=status_code,
                    latency_ms=latency_ms,
                    request_time=request_datetime,
                )
        if not is_success and retries >= max_retries:
            yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
            yield "data: [DONE]\n\n"
async def delete_old_request_logs_task():
    """Scheduled task: purge request logs older than the configured retention.

    Honors ``AUTO_DELETE_REQUEST_LOGS_ENABLED`` and
    ``AUTO_DELETE_REQUEST_LOGS_DAYS``; failures are logged, never raised.
    """
    if not settings.AUTO_DELETE_REQUEST_LOGS_ENABLED:
        logger.info(
            "Auto-delete for request logs is disabled by settings. Skipping task."
        )
        return

    retention_days = settings.AUTO_DELETE_REQUEST_LOGS_DAYS
    logger.info(
        f"Starting scheduled task to delete old request logs older than {retention_days} days."
    )

    try:
        threshold = datetime.now(timezone.utc) - timedelta(days=retention_days)

        # The scheduler may fire before the app's DB connection is up.
        if not database.is_connected:
            logger.info("Connecting to database for request log deletion.")
            await database.connect()

        stmt = delete(RequestLog).where(RequestLog.request_time < threshold)
        outcome = await database.execute(stmt)
        logger.info(
            f"Request logs older than {threshold} potentially deleted. Rows affected: {outcome.rowcount if outcome else 'N/A'}"
        )

    except Exception as e:
        logger.error(
            f"An error occurred during the scheduled request log deletion: {str(e)}",
            exc_info=True,
        )
return await self.get_calls_in_last_seconds(minutes * 60) + + async def get_calls_in_last_hours(self, hours: int) -> dict[str, int]: + """获取过去 N 小时内的调用次数 (总数、成功、失败)""" + return await self.get_calls_in_last_seconds(hours * 3600) + + async def get_calls_in_current_month(self) -> dict[str, int]: + """获取当前自然月内的调用次数 (总数、成功、失败)""" + try: + now = datetime.datetime.now() + start_of_month = now.replace( + day=1, hour=0, minute=0, second=0, microsecond=0 + ) + query = select( + func.count(RequestLog.id).label("total"), + func.sum( + case( + ( + and_( + RequestLog.status_code >= 200, + RequestLog.status_code < 300, + ), + 1, + ), + else_=0, + ) + ).label("success"), + func.sum( + case( + ( + or_( + RequestLog.status_code < 200, + RequestLog.status_code >= 300, + ), + 1, + ), + (RequestLog.status_code is None, 1), + else_=0, + ) + ).label("failure"), + ).where(RequestLog.request_time >= start_of_month) + result = await database.fetch_one(query) + if result: + return { + "total": result["total"] or 0, + "success": result["success"] or 0, + "failure": result["failure"] or 0, + } + return {"total": 0, "success": 0, "failure": 0} + except Exception as e: + logger.error(f"Failed to get calls in current month: {e}") + return {"total": 0, "success": 0, "failure": 0} + + async def get_api_usage_stats(self) -> dict: + """获取所有需要的 API 使用统计数据 (总数、成功、失败)""" + try: + stats_1m = await self.get_calls_in_last_minutes(1) + stats_1h = await self.get_calls_in_last_hours(1) + stats_24h = await self.get_calls_in_last_hours(24) + stats_month = await self.get_calls_in_current_month() + + return { + "calls_1m": stats_1m, + "calls_1h": stats_1h, + "calls_24h": stats_24h, + "calls_month": stats_month, + } + except Exception as e: + logger.error(f"Failed to get API usage stats: {e}") + default_stat = {"total": 0, "success": 0, "failure": 0} + return { + "calls_1m": default_stat.copy(), + "calls_1h": default_stat.copy(), + "calls_24h": default_stat.copy(), + "calls_month": default_stat.copy(), + } + + 
async def get_api_call_details(self, period: str) -> list[dict]: + """ + 获取指定时间段内的 API 调用详情 + + Args: + period: 时间段标识 ('1m', '1h', '24h') + + Returns: + 包含调用详情的字典列表,每个字典包含 timestamp, key, model, status + + Raises: + ValueError: 如果 period 无效 + """ + now = datetime.datetime.now() + if period == "1m": + start_time = now - datetime.timedelta(minutes=1) + elif period == "1h": + start_time = now - datetime.timedelta(hours=1) + elif period == "24h": + start_time = now - datetime.timedelta(hours=24) + else: + raise ValueError(f"无效的时间段标识: {period}") + + try: + query = ( + select( + RequestLog.request_time.label("timestamp"), + RequestLog.api_key.label("key"), + RequestLog.model_name.label("model"), + RequestLog.status_code, + ) + .where(RequestLog.request_time >= start_time) + .order_by(RequestLog.request_time.desc()) + ) + + results = await database.fetch_all(query) + + details = [] + for row in results: + status = "failure" + if row["status_code"] is not None: + status = "success" if 200 <= row["status_code"] < 300 else "failure" + details.append( + { + "timestamp": row[ + "timestamp" + ].isoformat(), + "key": row["key"], + "model": row["model"], + "status": status, + } + ) + logger.info( + f"Retrieved {len(details)} API call details for period '{period}'" + ) + return details + + except Exception as e: + logger.error( + f"Failed to get API call details for period '{period}': {e}") + raise + + async def get_key_usage_details_last_24h(self, key: str) -> Union[dict, None]: + """ + 获取指定 API 密钥在过去 24 小时内按模型统计的调用次数。 + + Args: + key: 要查询的 API 密钥。 + + Returns: + 一个字典,其中键是模型名称,值是调用次数。 + 如果查询出错或没有找到记录,可能返回 None 或空字典。 + Example: {"gemini-pro": 10, "gemini-1.5-pro-latest": 5} + """ + logger.info( + f"Fetching usage details for key ending in ...{key[-4:]} for the last 24h." 
+ ) + cutoff_time = datetime.datetime.now() - datetime.timedelta(hours=24) + + try: + query = ( + select( + RequestLog.model_name, func.count( + RequestLog.id).label("call_count") + ) + .where( + RequestLog.api_key == key, + RequestLog.request_time >= cutoff_time, + RequestLog.model_name.isnot(None), + ) + .group_by(RequestLog.model_name) + .order_by(func.count(RequestLog.id).desc()) + ) + + results = await database.fetch_all(query) + + if not results: + logger.info( + f"No usage details found for key ending in ...{key[-4:]} in the last 24h." + ) + return {} + + usage_details = {row["model_name"]: row["call_count"] + for row in results} + logger.info( + f"Successfully fetched usage details for key ending in ...{key[-4:]}: {usage_details}" + ) + return usage_details + + except Exception as e: + logger.error( + f"Failed to get key usage details for key ending in ...{key[-4:]}: {e}", + exc_info=True, + ) + raise diff --git a/app/service/tts/tts_service.py b/app/service/tts/tts_service.py new file mode 100644 index 0000000000000000000000000000000000000000..99a6074672fbc271703a9f2401b5547449a43be1 --- /dev/null +++ b/app/service/tts/tts_service.py @@ -0,0 +1,94 @@ +import datetime +import io +import re +import time +import wave +from typing import Optional + +from google import genai + +from app.config.config import settings +from app.database.services import add_error_log, add_request_log +from app.domain.openai_models import TTSRequest +from app.log.logger import get_openai_logger + +logger = get_openai_logger() + + +def _create_wav_file(audio_data: bytes) -> bytes: + """Creates a WAV file in memory from raw audio data.""" + with io.BytesIO() as wav_file: + with wave.open(wav_file, "wb") as wf: + wf.setnchannels(1) # Mono + wf.setsampwidth(2) # 16-bit + wf.setframerate(24000) # 24kHz sample rate + wf.writeframes(audio_data) + return wav_file.getvalue() + + +class TTSService: + async def create_tts(self, request: TTSRequest, api_key: str) -> Optional[bytes]: + """ + 使用 
Google Gemini SDK 创建音频。 + """ + start_time = time.perf_counter() + request_datetime = datetime.datetime.now() + is_success = False + status_code = None + response = None + error_log_msg = "" + try: + client = genai.Client(api_key=api_key) + response =await client.aio.models.generate_content( + model=settings.TTS_MODEL, + contents=f"Speak in a {settings.TTS_SPEED} speed voice: {request.input}", + config={ + "response_modalities": ["Audio"], + "speech_config": { + "voice_config": { + "prebuilt_voice_config": { + "voice_name": settings.TTS_VOICE_NAME + } + } + }, + }, + ) + if ( + response.candidates + and response.candidates[0].content.parts + and response.candidates[0].content.parts[0].inline_data + ): + raw_audio_data = response.candidates[0].content.parts[0].inline_data.data + is_success = True + status_code = 200 + return _create_wav_file(raw_audio_data) + except Exception as e: + is_success = False + error_log_msg = f"Generic error: {e}" + logger.error(f"An error occurred in TTSService: {error_log_msg}") + match = re.search(r"status code (\d+)", str(e)) + if match: + status_code = int(match.group(1)) + else: + status_code = 500 + raise + finally: + end_time = time.perf_counter() + latency_ms = int((end_time - start_time) * 1000) + if not is_success: + await add_error_log( + gemini_key=api_key, + model_name=settings.TTS_MODEL, + error_type="google-tts", + error_log=error_log_msg, + error_code=status_code, + request_msg=request.input + ) + await add_request_log( + model_name=settings.TTS_MODEL, + api_key=api_key, + is_success=is_success, + status_code=status_code, + latency_ms=latency_ms, + request_time=request_datetime + ) \ No newline at end of file diff --git a/app/service/update/update_service.py b/app/service/update/update_service.py new file mode 100644 index 0000000000000000000000000000000000000000..ae7ba981acacce0445bd14fc1b061f2d1cad717f --- /dev/null +++ b/app/service/update/update_service.py @@ -0,0 +1,95 @@ +import httpx +from packaging import version 
VERSION_FILE_PATH = "VERSION"


async def check_for_updates() -> Tuple[bool, Optional[str], Optional[str]]:
    """
    Check for an application update by comparing the local VERSION file
    against the latest GitHub release of the configured repository.

    Returns:
        Tuple[bool, Optional[str], Optional[str]]:
            - True when an update is available, otherwise False.
            - The latest version string when an update is available, else None.
            - An error message when the check failed, else None.
    """
    try:
        with open(VERSION_FILE_PATH, 'r', encoding='utf-8') as f:
            current_v = f.read().strip()
        if not current_v:
            logger.error(f"VERSION file ('{VERSION_FILE_PATH}') is empty.")
            return False, None, f"VERSION file ('{VERSION_FILE_PATH}') is empty."
    except FileNotFoundError:
        logger.error(f"VERSION file not found at '{VERSION_FILE_PATH}'. Make sure it exists in the project root.")
        return False, None, f"VERSION file not found at '{VERSION_FILE_PATH}'."
    except IOError as e:
        logger.error(f"Error reading VERSION file ('{VERSION_FILE_PATH}'): {e}")
        return False, None, f"Error reading VERSION file ('{VERSION_FILE_PATH}')."

    logger.info(f"当前应用程序版本 (from {VERSION_FILE_PATH}): {current_v}")

    # Skip the check when the repository is unconfigured or still on the
    # placeholder values shipped with the example config.
    if not settings.GITHUB_REPO_OWNER or not settings.GITHUB_REPO_NAME or \
            settings.GITHUB_REPO_OWNER == "your_owner" or settings.GITHUB_REPO_NAME == "your_repo":
        logger.warning("GitHub repository owner/name not configured in settings. Skipping update check.")
        return False, None, "Update check skipped: Repository not configured in settings."

    github_api_url = f"https://api.github.com/repos/{settings.GITHUB_REPO_OWNER}/{settings.GITHUB_REPO_NAME}/releases/latest"
    logger.debug(f"Checking for updates at URL: {github_api_url}")

    # Fix: keep latest_v_str always bound so the InvalidVersion handler
    # below no longer needs the fragile `'latest_v_str' in locals()` probe.
    latest_v_str: Optional[str] = None
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            headers = {
                "Accept": "application/vnd.github.v3+json",
                "User-Agent": f"{settings.GITHUB_REPO_NAME}-UpdateChecker/1.0"
            }
            response = await client.get(github_api_url, headers=headers)
            response.raise_for_status()

            latest_release = response.json()
            latest_v_str = latest_release.get("tag_name")

            if not latest_v_str:
                logger.warning("在最新的 GitHub release 响应中找不到 'tag_name'。")
                return False, None, "无法从 GitHub 解析最新版本。"

            # Release tags are conventionally prefixed with "v" (e.g. v1.2.3).
            latest_v_str = latest_v_str.removeprefix('v')

            logger.info(f"在 GitHub 上找到的最新版本: {latest_v_str}")

            current_version = version.parse(current_v)
            latest_version = version.parse(latest_v_str)

            if latest_version > current_version:
                logger.info(f"有可用更新: {current_v} -> {latest_v_str}")
                return True, latest_v_str, None
            logger.info("应用程序已是最新版本。")
            return False, None, None

    except httpx.HTTPStatusError as e:
        logger.error(f"检查更新时发生 HTTP 错误: {e.response.status_code} - {e.response.text}")
        # Avoid surfacing the raw response body to the caller.
        error_msg = f"获取更新信息失败 (HTTP {e.response.status_code})。"
        if e.response.status_code == 404:
            error_msg += " 请检查仓库名称是否正确或仓库是否有发布版本。"
        elif e.response.status_code == 403:
            error_msg += " API 速率限制或权限问题。"
        return False, None, error_msg
    except httpx.RequestError as e:
        logger.error(f"检查更新时发生网络错误: {e}")
        return False, None, "更新检查期间发生网络错误。"
    except version.InvalidVersion:
        latest_v_str_for_log = latest_v_str if latest_v_str is not None else 'N/A'
        logger.error(f"发现无效的版本格式。当前 (from {VERSION_FILE_PATH}): '{current_v}', 最新: '{latest_v_str_for_log}'")
        return False, None, "遇到无效的版本格式。"
    except Exception as e:
        logger.error(f"更新检查期间发生意外错误: {e}", exc_info=True)
        return False, None, "发生意外错误。"
a/app/static/icons/icon-192x192.png b/app/static/icons/icon-192x192.png new file mode 100644 index 0000000000000000000000000000000000000000..a3bc4655fbf7923b2092cc1e248d58aee52124f3 Binary files /dev/null and b/app/static/icons/icon-192x192.png differ diff --git a/app/static/icons/logo.png b/app/static/icons/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..3c2831b2c5e0470baf8f32c2b9dfa0f35d1cc291 Binary files /dev/null and b/app/static/icons/logo.png differ diff --git a/app/static/icons/logo1.png b/app/static/icons/logo1.png new file mode 100644 index 0000000000000000000000000000000000000000..685e96261e4c35217b95ad364f7eec1c291bf4ea Binary files /dev/null and b/app/static/icons/logo1.png differ diff --git a/app/static/js/config_editor.js b/app/static/js/config_editor.js new file mode 100644 index 0000000000000000000000000000000000000000..9c8de799fa2dd01be1ab9230104b32b00084c067 --- /dev/null +++ b/app/static/js/config_editor.js @@ -0,0 +1,2174 @@ +// Constants +const SENSITIVE_INPUT_CLASS = "sensitive-input"; +const ARRAY_ITEM_CLASS = "array-item"; +const ARRAY_INPUT_CLASS = "array-input"; +const MAP_ITEM_CLASS = "map-item"; +const MAP_KEY_INPUT_CLASS = "map-key-input"; +const MAP_VALUE_INPUT_CLASS = "map-value-input"; +const SAFETY_SETTING_ITEM_CLASS = "safety-setting-item"; +const SHOW_CLASS = "show"; // For modals +const API_KEY_REGEX = /AIzaSy\S{33}/g; +const PROXY_REGEX = + /(?:https?|socks5):\/\/(?:[^:@\/]+(?::[^@\/]+)?@)?(?:[^:\/\s]+)(?::\d+)?/g; +const VERTEX_API_KEY_REGEX = /AQ\.[a-zA-Z0-9_]{50}/g; // 新增 Vertex API Key 正则 +const MASKED_VALUE = "••••••••"; + +// DOM Elements - Global Scope for frequently accessed elements +const safetySettingsContainer = document.getElementById( + "SAFETY_SETTINGS_container" +); +const thinkingModelsContainer = document.getElementById( + "THINKING_MODELS_container" +); +const apiKeyModal = document.getElementById("apiKeyModal"); +const apiKeyBulkInput = 
// Modal Control Functions

/**
 * Show a modal dialog by adding the shared "show" class.
 * Silently does nothing when the element is absent from the DOM.
 * @param {HTMLElement|null} modalElement
 */
function openModal(modalElement) {
  if (!modalElement) return;
  modalElement.classList.add(SHOW_CLASS);
}

/**
 * Hide a modal dialog by removing the shared "show" class.
 * Silently does nothing when the element is absent from the DOM.
 * @param {HTMLElement|null} modalElement
 */
function closeModal(modalElement) {
  if (!modalElement) return;
  modalElement.classList.remove(SHOW_CLASS);
}
closeApiKeyModalBtn.addEventListener("click", () => + closeModal(apiKeyModal) + ); + if (cancelAddApiKeyBtn) + cancelAddApiKeyBtn.addEventListener("click", () => closeModal(apiKeyModal)); + if (confirmAddApiKeyBtn) + confirmAddApiKeyBtn.addEventListener("click", handleBulkAddApiKeys); + if (apiKeySearchInput) + apiKeySearchInput.addEventListener("input", handleApiKeySearch); + + // Bulk Delete API Key Modal Elements and Events + const bulkDeleteApiKeyBtn = document.getElementById("bulkDeleteApiKeyBtn"); + const closeBulkDeleteModalBtn = document.getElementById( + "closeBulkDeleteModalBtn" + ); + const cancelBulkDeleteApiKeyBtn = document.getElementById( + "cancelBulkDeleteApiKeyBtn" + ); + const confirmBulkDeleteApiKeyBtn = document.getElementById( + "confirmBulkDeleteApiKeyBtn" + ); + + if (bulkDeleteApiKeyBtn) { + bulkDeleteApiKeyBtn.addEventListener("click", () => { + openModal(bulkDeleteApiKeyModal); + if (bulkDeleteApiKeyInput) bulkDeleteApiKeyInput.value = ""; + }); + } + if (closeBulkDeleteModalBtn) + closeBulkDeleteModalBtn.addEventListener("click", () => + closeModal(bulkDeleteApiKeyModal) + ); + if (cancelBulkDeleteApiKeyBtn) + cancelBulkDeleteApiKeyBtn.addEventListener("click", () => + closeModal(bulkDeleteApiKeyModal) + ); + if (confirmBulkDeleteApiKeyBtn) + confirmBulkDeleteApiKeyBtn.addEventListener( + "click", + handleBulkDeleteApiKeys + ); + + // Proxy Modal Elements and Events + const addProxyBtn = document.getElementById("addProxyBtn"); + const closeProxyModalBtn = document.getElementById("closeProxyModalBtn"); + const cancelAddProxyBtn = document.getElementById("cancelAddProxyBtn"); + const confirmAddProxyBtn = document.getElementById("confirmAddProxyBtn"); + + if (addProxyBtn) { + addProxyBtn.addEventListener("click", () => { + openModal(proxyModal); + if (proxyBulkInput) proxyBulkInput.value = ""; + }); + } + if (closeProxyModalBtn) + closeProxyModalBtn.addEventListener("click", () => closeModal(proxyModal)); + if (cancelAddProxyBtn) + 
cancelAddProxyBtn.addEventListener("click", () => closeModal(proxyModal)); + if (confirmAddProxyBtn) + confirmAddProxyBtn.addEventListener("click", handleBulkAddProxies); + + // Bulk Delete Proxy Modal Elements and Events + const bulkDeleteProxyBtn = document.getElementById("bulkDeleteProxyBtn"); + const closeBulkDeleteProxyModalBtn = document.getElementById( + "closeBulkDeleteProxyModalBtn" + ); + const cancelBulkDeleteProxyBtn = document.getElementById( + "cancelBulkDeleteProxyBtn" + ); + const confirmBulkDeleteProxyBtn = document.getElementById( + "confirmBulkDeleteProxyBtn" + ); + + if (bulkDeleteProxyBtn) { + bulkDeleteProxyBtn.addEventListener("click", () => { + openModal(bulkDeleteProxyModal); + if (bulkDeleteProxyInput) bulkDeleteProxyInput.value = ""; + }); + } + if (closeBulkDeleteProxyModalBtn) + closeBulkDeleteProxyModalBtn.addEventListener("click", () => + closeModal(bulkDeleteProxyModal) + ); + if (cancelBulkDeleteProxyBtn) + cancelBulkDeleteProxyBtn.addEventListener("click", () => + closeModal(bulkDeleteProxyModal) + ); + if (confirmBulkDeleteProxyBtn) + confirmBulkDeleteProxyBtn.addEventListener( + "click", + handleBulkDeleteProxies + ); + + // Reset Confirmation Modal Elements and Events + const closeResetModalBtn = document.getElementById("closeResetModalBtn"); + const cancelResetBtn = document.getElementById("cancelResetBtn"); + const confirmResetBtn = document.getElementById("confirmResetBtn"); + + if (closeResetModalBtn) + closeResetModalBtn.addEventListener("click", () => + closeModal(resetConfirmModal) + ); + if (cancelResetBtn) + cancelResetBtn.addEventListener("click", () => + closeModal(resetConfirmModal) + ); + if (confirmResetBtn) { + confirmResetBtn.addEventListener("click", () => { + closeModal(resetConfirmModal); + executeReset(); + }); + } + + // Click outside modal to close + window.addEventListener("click", (event) => { + const modals = [ + apiKeyModal, + resetConfirmModal, + bulkDeleteApiKeyModal, + proxyModal, + 
bulkDeleteProxyModal, + vertexApiKeyModal, // 新增 + bulkDeleteVertexApiKeyModal, // 新增 + modelHelperModal, + ]; + modals.forEach((modal) => { + if (event.target === modal) { + closeModal(modal); + } + }); + }); + + // Removed static token generation button event listener, now handled dynamically if needed or by specific buttons. + + // Authentication token generation button + const generateAuthTokenBtn = document.getElementById("generateAuthTokenBtn"); + const authTokenInput = document.getElementById("AUTH_TOKEN"); + if (generateAuthTokenBtn && authTokenInput) { + generateAuthTokenBtn.addEventListener("click", function () { + const newToken = generateRandomToken(); // Assuming generateRandomToken is defined elsewhere + authTokenInput.value = newToken; + if (authTokenInput.classList.contains(SENSITIVE_INPUT_CLASS)) { + const event = new Event("focusout", { + bubbles: true, + cancelable: true, + }); + authTokenInput.dispatchEvent(event); + } + showNotification("已生成新认证令牌", "success"); + }); + } + + // Event delegation for THINKING_MODELS input changes to update budget map keys + if (thinkingModelsContainer) { + thinkingModelsContainer.addEventListener("input", function (event) { + const target = event.target; + if ( + target && + target.classList.contains(ARRAY_INPUT_CLASS) && + target.closest(`.${ARRAY_ITEM_CLASS}[data-model-id]`) + ) { + const modelInput = target; + const modelItem = modelInput.closest(`.${ARRAY_ITEM_CLASS}`); + const modelId = modelItem.getAttribute("data-model-id"); + const budgetKeyInput = document.querySelector( + `.${MAP_KEY_INPUT_CLASS}[data-model-id="${modelId}"]` + ); + if (budgetKeyInput) { + budgetKeyInput.value = modelInput.value; + } + } + }); + } + + // Event delegation for dynamically added remove buttons and generate token buttons within array items + if (configForm) { + // Ensure configForm exists before adding event listener + configForm.addEventListener("click", function (event) { + const target = event.target; + const removeButton 
= target.closest(".remove-btn"); + const generateButton = target.closest(".generate-btn"); + + if (removeButton && removeButton.closest(`.${ARRAY_ITEM_CLASS}`)) { + const arrayItem = removeButton.closest(`.${ARRAY_ITEM_CLASS}`); + const parentContainer = arrayItem.parentElement; + const isThinkingModelItem = + arrayItem.hasAttribute("data-model-id") && + parentContainer && + parentContainer.id === "THINKING_MODELS_container"; + const isSafetySettingItem = arrayItem.classList.contains( + SAFETY_SETTING_ITEM_CLASS + ); + + if (isThinkingModelItem) { + const modelId = arrayItem.getAttribute("data-model-id"); + const budgetMapItem = document.querySelector( + `.${MAP_ITEM_CLASS}[data-model-id="${modelId}"]` + ); + if (budgetMapItem) { + budgetMapItem.remove(); + } + // Check and add placeholder for budget map if empty + const budgetContainer = document.getElementById( + "THINKING_BUDGET_MAP_container" + ); + if (budgetContainer && budgetContainer.children.length === 0) { + budgetContainer.innerHTML = + '
请在上方添加思考模型,预算将自动关联。
'; + } + } + arrayItem.remove(); + // Check and add placeholder for safety settings if empty + if ( + isSafetySettingItem && + parentContainer && + parentContainer.children.length === 0 + ) { + parentContainer.innerHTML = + '
定义模型的安全过滤阈值。
'; + } + } else if ( + generateButton && + generateButton.closest(`.${ARRAY_ITEM_CLASS}`) + ) { + const inputField = generateButton + .closest(`.${ARRAY_ITEM_CLASS}`) + .querySelector(`.${ARRAY_INPUT_CLASS}`); + if (inputField) { + const newToken = generateRandomToken(); + inputField.value = newToken; + if (inputField.classList.contains(SENSITIVE_INPUT_CLASS)) { + const event = new Event("focusout", { + bubbles: true, + cancelable: true, + }); + inputField.dispatchEvent(event); + } + showNotification("已生成新令牌", "success"); + } + } + }); + } + + // Add Safety Setting button + const addSafetySettingBtn = document.getElementById("addSafetySettingBtn"); + if (addSafetySettingBtn) { + addSafetySettingBtn.addEventListener("click", () => addSafetySettingItem()); + } + + initializeSensitiveFields(); // Initialize sensitive field handling + + // Vertex API Key Modal Elements and Events + const addVertexApiKeyBtn = document.getElementById("addVertexApiKeyBtn"); + const closeVertexApiKeyModalBtn = document.getElementById( + "closeVertexApiKeyModalBtn" + ); + const cancelAddVertexApiKeyBtn = document.getElementById( + "cancelAddVertexApiKeyBtn" + ); + const confirmAddVertexApiKeyBtn = document.getElementById( + "confirmAddVertexApiKeyBtn" + ); + const bulkDeleteVertexApiKeyBtn = document.getElementById( + "bulkDeleteVertexApiKeyBtn" + ); + const closeBulkDeleteVertexModalBtn = document.getElementById( + "closeBulkDeleteVertexModalBtn" + ); + const cancelBulkDeleteVertexApiKeyBtn = document.getElementById( + "cancelBulkDeleteVertexApiKeyBtn" + ); + const confirmBulkDeleteVertexApiKeyBtn = document.getElementById( + "confirmBulkDeleteVertexApiKeyBtn" + ); + + if (addVertexApiKeyBtn) { + addVertexApiKeyBtn.addEventListener("click", () => { + openModal(vertexApiKeyModal); + if (vertexApiKeyBulkInput) vertexApiKeyBulkInput.value = ""; + }); + } + if (closeVertexApiKeyModalBtn) + closeVertexApiKeyModalBtn.addEventListener("click", () => + closeModal(vertexApiKeyModal) + ); + if 
(cancelAddVertexApiKeyBtn) + cancelAddVertexApiKeyBtn.addEventListener("click", () => + closeModal(vertexApiKeyModal) + ); + if (confirmAddVertexApiKeyBtn) + confirmAddVertexApiKeyBtn.addEventListener( + "click", + handleBulkAddVertexApiKeys + ); + + if (bulkDeleteVertexApiKeyBtn) { + bulkDeleteVertexApiKeyBtn.addEventListener("click", () => { + openModal(bulkDeleteVertexApiKeyModal); + if (bulkDeleteVertexApiKeyInput) bulkDeleteVertexApiKeyInput.value = ""; + }); + } + if (closeBulkDeleteVertexModalBtn) + closeBulkDeleteVertexModalBtn.addEventListener("click", () => + closeModal(bulkDeleteVertexApiKeyModal) + ); + if (cancelBulkDeleteVertexApiKeyBtn) + cancelBulkDeleteVertexApiKeyBtn.addEventListener("click", () => + closeModal(bulkDeleteVertexApiKeyModal) + ); + if (confirmBulkDeleteVertexApiKeyBtn) + confirmBulkDeleteVertexApiKeyBtn.addEventListener( + "click", + handleBulkDeleteVertexApiKeys + ); + + // Model Helper Modal Event Listeners + if (closeModelHelperModalBtn) { + closeModelHelperModalBtn.addEventListener("click", () => + closeModal(modelHelperModal) + ); + } + if (cancelModelHelperBtn) { + cancelModelHelperBtn.addEventListener("click", () => + closeModal(modelHelperModal) + ); + } + if (modelHelperSearchInput) { + modelHelperSearchInput.addEventListener("input", () => + renderModelsInModal() + ); + } + + // Add event listeners to all model helper trigger buttons + const modelHelperTriggerBtns = document.querySelectorAll( + ".model-helper-trigger-btn" + ); + modelHelperTriggerBtns.forEach((btn) => { + btn.addEventListener("click", () => { + const targetInputId = btn.dataset.targetInputId; + const targetArrayKey = btn.dataset.targetArrayKey; + + if (targetInputId) { + currentModelHelperTarget = { + type: "input", + target: document.getElementById(targetInputId), + }; + } else if (targetArrayKey) { + currentModelHelperTarget = { type: "array", targetKey: targetArrayKey }; + } + openModelHelperModal(); + }); + }); +}); // <-- DOMContentLoaded end + +/** + 
 * Initializes sensitive input field behavior (masking/unmasking).
 */
function initializeSensitiveFields() {
  if (!configForm) return;

  // Helper function: Mask field
  function maskField(field) {
    if (field.value && field.value !== MASKED_VALUE) {
      field.setAttribute("data-real-value", field.value);
      field.value = MASKED_VALUE;
    } else if (!field.value) {
      // If field value is empty string
      field.removeAttribute("data-real-value");
      // Ensure empty value doesn't show as asterisks
      if (field.value === MASKED_VALUE) field.value = "";
    }
  }

  // Helper function: Unmask field
  function unmaskField(field) {
    if (field.hasAttribute("data-real-value")) {
      field.value = field.getAttribute("data-real-value");
    }
    // If no data-real-value and value is MASKED_VALUE, it might be an initial empty sensitive field, clear it
    else if (
      field.value === MASKED_VALUE &&
      !field.hasAttribute("data-real-value")
    ) {
      field.value = "";
    }
  }

  // Initial masking for existing sensitive fields on page load
  // This function is called after populateForm and after dynamic element additions (via event delegation)
  function initialMaskAllExisting() {
    const sensitiveFields = configForm.querySelectorAll(
      `.${SENSITIVE_INPUT_CLASS}`
    );
    sensitiveFields.forEach((field) => {
      if (field.type === "password") {
        // For password fields, browser handles it. We just ensure data-original-type is set
        // and if it has a value, we also store data-real-value so it can be shown when switched to text
        if (field.value) {
          field.setAttribute("data-real-value", field.value);
        }
        // No need to set to MASKED_VALUE as browser handles it.
      } else if (
        field.type === "text" ||
        field.tagName.toLowerCase() === "textarea"
      ) {
        maskField(field);
      }
    });
  }
  initialMaskAllExisting();

  // Event delegation for dynamic and static fields
  configForm.addEventListener("focusin", function (event) {
    const target = event.target;
    if (target.classList.contains(SENSITIVE_INPUT_CLASS)) {
      if (target.type === "password") {
        // Record original type to switch back on blur
        if (!target.hasAttribute("data-original-type")) {
          target.setAttribute("data-original-type", "password");
        }
        target.type = "text"; // Switch to text type to show content
        // If data-real-value exists (e.g., set during populateForm), use it
        if (target.hasAttribute("data-real-value")) {
          target.value = target.getAttribute("data-real-value");
        }
        // Otherwise, the browser's existing password value will be shown directly
      } else {
        // For type="text" or textarea
        unmaskField(target);
      }
    }
  });

  configForm.addEventListener("focusout", function (event) {
    const target = event.target;
    if (target.classList.contains(SENSITIVE_INPUT_CLASS)) {
      // First, if the field is currently text and has a value, update data-real-value
      if (
        target.type === "text" ||
        target.tagName.toLowerCase() === "textarea"
      ) {
        if (target.value && target.value !== MASKED_VALUE) {
          target.setAttribute("data-real-value", target.value);
        } else if (!target.value) {
          // If value is empty, remove data-real-value
          target.removeAttribute("data-real-value");
        }
      }

      // Then handle type switching and masking
      if (
        target.getAttribute("data-original-type") === "password" &&
        target.type === "text"
      ) {
        target.type = "password"; // Switch back to password type
        // For password type, browser handles masking automatically, no need to set MASKED_VALUE manually
        // data-real-value has already been updated by the logic above
      } else if (
        target.type === "text" ||
        target.tagName.toLowerCase() === "textarea"
      ) {
        // For text or textarea sensitive fields, perform masking
        maskField(target);
      }
    }
  });
}

/**
 * Generates a UUID.
 * @returns {string} A new UUID.
 */
function generateUUID() {
  return "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx".replace(/[xy]/g, function (c) {
    var r = (Math.random() * 16) | 0,
      v = c == "x" ? r : (r & 0x3) | 0x8;
    return v.toString(16);
  });
}

/**
 * Initializes the configuration by fetching it from the server and populating the form.
 */
async function initConfig() {
  try {
    showNotification("正在加载配置...", "info");
    const response = await fetch("/api/config");

    if (!response.ok) {
      throw new Error(`HTTP error! status: ${response.status}`);
    }

    const config = await response.json();

    // Ensure array fields have sensible defaults
    if (
      !config.API_KEYS ||
      !Array.isArray(config.API_KEYS) ||
      config.API_KEYS.length === 0
    ) {
      config.API_KEYS = ["请在此处输入 API 密钥"];
    }

    if (
      !config.ALLOWED_TOKENS ||
      !Array.isArray(config.ALLOWED_TOKENS) ||
      config.ALLOWED_TOKENS.length === 0
    ) {
      config.ALLOWED_TOKENS = [""];
    }

    if (
      !config.IMAGE_MODELS ||
      !Array.isArray(config.IMAGE_MODELS) ||
      config.IMAGE_MODELS.length === 0
    ) {
      config.IMAGE_MODELS = ["gemini-1.5-pro-latest"];
    }

    if (
      !config.SEARCH_MODELS ||
      !Array.isArray(config.SEARCH_MODELS) ||
      config.SEARCH_MODELS.length === 0
    ) {
      config.SEARCH_MODELS = ["gemini-1.5-flash-latest"];
    }

    if (
      !config.FILTERED_MODELS ||
      !Array.isArray(config.FILTERED_MODELS) ||
      config.FILTERED_MODELS.length === 0
    ) {
      config.FILTERED_MODELS = ["gemini-1.0-pro-latest"];
    }
    // --- Added: default for VERTEX_API_KEYS ---
    if (!config.VERTEX_API_KEYS || !Array.isArray(config.VERTEX_API_KEYS)) {
      config.VERTEX_API_KEYS = [];
    }
    // --- Added: default for VERTEX_EXPRESS_BASE_URL ---
    if (typeof config.VERTEX_EXPRESS_BASE_URL === "undefined") {
      config.VERTEX_EXPRESS_BASE_URL = "";
    }
    // --- Added: default for PROXIES ---
    if (!config.PROXIES || !Array.isArray(config.PROXIES)) {
      config.PROXIES = []; // defaults to an empty array
    }
    // --- Added: defaults for new fields ---
    if (!config.THINKING_MODELS || !Array.isArray(config.THINKING_MODELS)) {
      config.THINKING_MODELS = []; // defaults to an empty array
    }
    if (
      !config.THINKING_BUDGET_MAP ||
      typeof config.THINKING_BUDGET_MAP !== "object" ||
      config.THINKING_BUDGET_MAP === null
    ) {
      config.THINKING_BUDGET_MAP = {}; // defaults to an empty object
    }
    // --- Added: default for SAFETY_SETTINGS ---
    if (!config.SAFETY_SETTINGS || !Array.isArray(config.SAFETY_SETTINGS)) {
      config.SAFETY_SETTINGS = []; // defaults to an empty array
    }
    // --- End: SAFETY_SETTINGS default ---

    // --- Added: defaults for auto-delete error log settings ---
    if (typeof config.AUTO_DELETE_ERROR_LOGS_ENABLED === "undefined") {
      config.AUTO_DELETE_ERROR_LOGS_ENABLED = false;
    }
    if (typeof config.AUTO_DELETE_ERROR_LOGS_DAYS === "undefined") {
      config.AUTO_DELETE_ERROR_LOGS_DAYS = 7;
    }
    // --- End: auto-delete error log defaults ---

    // --- Added: defaults for auto-delete request log settings ---
    if (typeof config.AUTO_DELETE_REQUEST_LOGS_ENABLED === "undefined") {
      config.AUTO_DELETE_REQUEST_LOGS_ENABLED = false;
    }
    if (typeof config.AUTO_DELETE_REQUEST_LOGS_DAYS === "undefined") {
      config.AUTO_DELETE_REQUEST_LOGS_DAYS = 30;
    }
    // --- End: auto-delete request log defaults ---

    // --- Added: defaults for fake-stream settings ---
    if (typeof config.FAKE_STREAM_ENABLED === "undefined") {
      config.FAKE_STREAM_ENABLED = false;
    }
    if (typeof config.FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS === "undefined") {
      config.FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS = 5;
    }
    // --- End: fake-stream defaults ---

    populateForm(config);
    // After populateForm, initialize masking for all populated sensitive fields
    if (configForm) {
      // Ensure form exists
      initializeSensitiveFields(); // Call initializeSensitiveFields to handle initial masking
    }

    // Ensure upload provider has a default value
    const uploadProvider = document.getElementById("UPLOAD_PROVIDER");
    if (uploadProvider && !uploadProvider.value) {
      uploadProvider.value = "smms"; // default provider is smms
      toggleProviderConfig("smms");
    }

    showNotification("配置加载成功", "success");
  } catch (error) {
    console.error("加载配置失败:", error);
    showNotification("加载配置失败: " + error.message, "error");

    // On load failure, fall back to a default configuration
    const defaultConfig = {
      API_KEYS: [""],
      ALLOWED_TOKENS: [""],
      IMAGE_MODELS: ["gemini-1.5-pro-latest"],
      SEARCH_MODELS: ["gemini-1.5-flash-latest"],
      FILTERED_MODELS: ["gemini-1.0-pro-latest"],
      UPLOAD_PROVIDER: "smms",
      PROXIES: [],
      VERTEX_API_KEYS: [], // ensure the default exists
      VERTEX_EXPRESS_BASE_URL: "", // ensure the default exists
      THINKING_MODELS: [],
      THINKING_BUDGET_MAP: {},
      AUTO_DELETE_ERROR_LOGS_ENABLED: false,
      AUTO_DELETE_ERROR_LOGS_DAYS: 7, // added default
      AUTO_DELETE_REQUEST_LOGS_ENABLED: false, // added default
      AUTO_DELETE_REQUEST_LOGS_DAYS: 30, // added default
      // --- Added: fake-stream defaults ---
      FAKE_STREAM_ENABLED: false,
      FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS: 5,
      // --- End: fake-stream defaults ---
    };

    populateForm(defaultConfig);
    if (configForm) {
      // Ensure form exists
      initializeSensitiveFields(); // Call initializeSensitiveFields to handle initial masking
    }
    toggleProviderConfig("smms");
  }
}

/**
 * Populates the configuration form with data.
 * @param {object} config - The configuration object.
 */
function populateForm(config) {
  const modelIdMap = {}; // modelName -> modelId

  // 1. Clear existing dynamic content first
  const arrayContainers = document.querySelectorAll(".array-container");
  arrayContainers.forEach((container) => {
    container.innerHTML = ""; // Clear all array containers
  });
  const budgetMapContainer = document.getElementById(
    "THINKING_BUDGET_MAP_container"
  );
  if (budgetMapContainer) {
    budgetMapContainer.innerHTML = ""; // Clear budget map container
  } else {
    console.error("Critical: THINKING_BUDGET_MAP_container not found!");
    return; // Cannot proceed
  }

  // 2. Populate THINKING_MODELS and build the map
  if (Array.isArray(config.THINKING_MODELS)) {
    const container = document.getElementById("THINKING_MODELS_container");
    if (container) {
      config.THINKING_MODELS.forEach((modelName) => {
        if (modelName && typeof modelName === "string" && modelName.trim()) {
          const trimmedModelName = modelName.trim();
          const modelId = addArrayItemWithValue(
            "THINKING_MODELS",
            trimmedModelName
          );
          if (modelId) {
            modelIdMap[trimmedModelName] = modelId;
          } else {
            console.warn(
              `Failed to get modelId for THINKING_MODEL: '${trimmedModelName}'`
            );
          }
        } else {
          console.warn(`Invalid THINKING_MODEL entry found:`, modelName);
        }
      });
    } else {
      console.error("Critical: THINKING_MODELS_container not found!");
    }
  }

  // 3. Populate THINKING_BUDGET_MAP using the map
  let budgetItemsAdded = false;
  if (
    config.THINKING_BUDGET_MAP &&
    typeof config.THINKING_BUDGET_MAP === "object"
  ) {
    for (const [modelName, budgetValue] of Object.entries(
      config.THINKING_BUDGET_MAP
    )) {
      if (modelName && typeof modelName === "string") {
        const trimmedModelName = modelName.trim();
        const modelId = modelIdMap[trimmedModelName]; // Look up the ID
        if (modelId) {
          createAndAppendBudgetMapItem(trimmedModelName, budgetValue, modelId);
          budgetItemsAdded = true;
        } else {
          console.warn(
            `Budget map: Could not find model ID for '${trimmedModelName}'. Skipping budget item.`
          );
        }
      } else {
        console.warn(`Invalid key found in THINKING_BUDGET_MAP:`, modelName);
      }
    }
  }
  if (!budgetItemsAdded && budgetMapContainer) {
    // NOTE(review): the placeholder markup appears garbled in this source
    // dump — confirm the wrapper HTML against the original file.
    budgetMapContainer.innerHTML =
      '请在上方添加思考模型,预算将自动关联。';
  }

  // 4. Populate other array fields (excluding THINKING_MODELS)
  for (const [key, value] of Object.entries(config)) {
    if (Array.isArray(value) && key !== "THINKING_MODELS") {
      const container = document.getElementById(`${key}_container`);
      if (container) {
        value.forEach((itemValue) => {
          if (typeof itemValue === "string") {
            addArrayItemWithValue(key, itemValue);
          } else {
            console.warn(`Invalid item found in array '${key}':`, itemValue);
          }
        });
      }
    }
  }

  // 5. Populate non-array/non-budget fields
  for (const [key, value] of Object.entries(config)) {
    if (
      !Array.isArray(value) &&
      !(
        typeof value === "object" &&
        value !== null &&
        key === "THINKING_BUDGET_MAP"
      )
    ) {
      const element = document.getElementById(key);
      if (element) {
        if (element.type === "checkbox" && typeof value === "boolean") {
          element.checked = value;
        } else if (element.type !== "checkbox") {
          if (key === "LOG_LEVEL" && typeof value === "string") {
            element.value = value.toUpperCase();
          } else {
            element.value = value !== null && value !== undefined ? value : "";
          }
        }
      }
    }
  }

  // 6. Initialize upload provider
  const uploadProvider = document.getElementById("UPLOAD_PROVIDER");
  if (uploadProvider) {
    toggleProviderConfig(uploadProvider.value);
  }

  // Populate SAFETY_SETTINGS
  let safetyItemsAdded = false;
  if (safetySettingsContainer && Array.isArray(config.SAFETY_SETTINGS)) {
    config.SAFETY_SETTINGS.forEach((setting) => {
      if (
        setting &&
        typeof setting === "object" &&
        setting.category &&
        setting.threshold
      ) {
        addSafetySettingItem(setting.category, setting.threshold);
        safetyItemsAdded = true;
      } else {
        console.warn("Invalid safety setting item found:", setting);
      }
    });
  }
  if (safetySettingsContainer && !safetyItemsAdded) {
    // NOTE(review): placeholder markup likely garbled here too — verify.
    safetySettingsContainer.innerHTML =
      '定义模型的安全过滤阈值。';
  }

  // --- Added: auto-delete error log fields ---
  const autoDeleteEnabledCheckbox = document.getElementById(
    "AUTO_DELETE_ERROR_LOGS_ENABLED"
  );
  const autoDeleteDaysSelect = document.getElementById(
    "AUTO_DELETE_ERROR_LOGS_DAYS"
  );

  if (autoDeleteEnabledCheckbox && autoDeleteDaysSelect) {
    autoDeleteEnabledCheckbox.checked = !!config.AUTO_DELETE_ERROR_LOGS_ENABLED; // coerce to boolean
    autoDeleteDaysSelect.value = config.AUTO_DELETE_ERROR_LOGS_DAYS || 7; // 7 days by default

    // Disable the day selector while auto-delete is off
    autoDeleteDaysSelect.disabled = !autoDeleteEnabledCheckbox.checked;

    // Keep the selector's disabled state in sync with the checkbox
    autoDeleteEnabledCheckbox.addEventListener("change", function () {
      autoDeleteDaysSelect.disabled = !this.checked;
    });
  }
  // --- End: auto-delete error log fields ---

  // --- Added: auto-delete request log fields ---
  const autoDeleteRequestEnabledCheckbox = document.getElementById(
    "AUTO_DELETE_REQUEST_LOGS_ENABLED"
  );
  const autoDeleteRequestDaysSelect = document.getElementById(
    "AUTO_DELETE_REQUEST_LOGS_DAYS"
  );

  if (autoDeleteRequestEnabledCheckbox && autoDeleteRequestDaysSelect) {
    autoDeleteRequestEnabledCheckbox.checked =
      !!config.AUTO_DELETE_REQUEST_LOGS_ENABLED;
    autoDeleteRequestDaysSelect.value =
      config.AUTO_DELETE_REQUEST_LOGS_DAYS || 30;
    autoDeleteRequestDaysSelect.disabled =
      !autoDeleteRequestEnabledCheckbox.checked;

    autoDeleteRequestEnabledCheckbox.addEventListener("change", function () {
      autoDeleteRequestDaysSelect.disabled = !this.checked;
    });
  }
  // --- End: auto-delete request log fields ---

  // --- Added: fake-stream fields ---
  const fakeStreamEnabledCheckbox = document.getElementById(
    "FAKE_STREAM_ENABLED"
  );
  const fakeStreamIntervalInput = document.getElementById(
    "FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS"
  );

  if (fakeStreamEnabledCheckbox && fakeStreamIntervalInput) {
    fakeStreamEnabledCheckbox.checked = !!config.FAKE_STREAM_ENABLED;
    fakeStreamIntervalInput.value =
      config.FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS || 5;
    // Optionally disable the interval input based on the checkbox state:
    // fakeStreamIntervalInput.disabled = !fakeStreamEnabledCheckbox.checked;
    // fakeStreamEnabledCheckbox.addEventListener("change", function () {
    //   fakeStreamIntervalInput.disabled = !this.checked;
    // });
  }
  // --- End: fake-stream fields ---
}

/**
 * Handles the bulk addition of API keys from the modal input.
 */
function handleBulkAddApiKeys() {
  const apiKeyContainer = document.getElementById("API_KEYS_container");
  if (!apiKeyBulkInput || !apiKeyContainer || !apiKeyModal) return;

  const bulkText = apiKeyBulkInput.value;
  const extractedKeys = bulkText.match(API_KEY_REGEX) || [];

  // Collect the real (unmasked) values already present in the list.
  const currentKeyInputs = apiKeyContainer.querySelectorAll(
    `.${ARRAY_INPUT_CLASS}.${SENSITIVE_INPUT_CLASS}`
  );
  let currentKeys = Array.from(currentKeyInputs)
    .map((input) => {
      return input.hasAttribute("data-real-value")
        ? input.getAttribute("data-real-value")
        : input.value;
    })
    .filter((key) => key && key.trim() !== "" && key !== MASKED_VALUE);

  const combinedKeys = new Set([...currentKeys, ...extractedKeys]);
  const uniqueKeys = Array.from(combinedKeys);

  apiKeyContainer.innerHTML = ""; // Clear existing items more directly

  uniqueKeys.forEach((key) => {
    addArrayItemWithValue("API_KEYS", key);
  });

  // Re-mask the freshly created sensitive inputs via synthetic focusout events.
  const newKeyInputs = apiKeyContainer.querySelectorAll(
    `.${ARRAY_INPUT_CLASS}.${SENSITIVE_INPUT_CLASS}`
  );
  newKeyInputs.forEach((input) => {
    if (configForm && typeof initializeSensitiveFields === "function") {
      const focusoutEvent = new Event("focusout", {
        bubbles: true,
        cancelable: true,
      });
      input.dispatchEvent(focusoutEvent);
    }
  });

  closeModal(apiKeyModal);
  showNotification(`添加/更新了 ${uniqueKeys.length} 个唯一密钥`, "success");
}

/**
 * Handles searching/filtering of API keys in the list.
 */
function handleApiKeySearch() {
  const apiKeyContainer = document.getElementById("API_KEYS_container");
  if (!apiKeySearchInput || !apiKeyContainer) return;

  const searchTerm = apiKeySearchInput.value.toLowerCase();
  const keyItems = apiKeyContainer.querySelectorAll(`.${ARRAY_ITEM_CLASS}`);

  keyItems.forEach((item) => {
    const input = item.querySelector(
      `.${ARRAY_INPUT_CLASS}.${SENSITIVE_INPUT_CLASS}`
    );
    if (input) {
      // Match against the real value, not the masked placeholder.
      const realValue = input.hasAttribute("data-real-value")
        ? input.getAttribute("data-real-value").toLowerCase()
        : input.value.toLowerCase();
      item.style.display = realValue.includes(searchTerm) ? "flex" : "none";
    }
  });
}

/**
 * Handles the bulk deletion of API keys based on input from the modal.
 */
function handleBulkDeleteApiKeys() {
  const apiKeyContainer = document.getElementById("API_KEYS_container");
  if (!bulkDeleteApiKeyInput || !apiKeyContainer || !bulkDeleteApiKeyModal)
    return;

  const bulkText = bulkDeleteApiKeyInput.value;
  if (!bulkText.trim()) {
    showNotification("请粘贴需要删除的 API 密钥", "warning");
    return;
  }

  const keysToDelete = new Set(bulkText.match(API_KEY_REGEX) || []);

  if (keysToDelete.size === 0) {
    showNotification("未在输入内容中提取到有效的 API 密钥格式", "warning");
    return;
  }

  const keyItems = apiKeyContainer.querySelectorAll(`.${ARRAY_ITEM_CLASS}`);
  let deleteCount = 0;

  keyItems.forEach((item) => {
    const input = item.querySelector(
      `.${ARRAY_INPUT_CLASS}.${SENSITIVE_INPUT_CLASS}`
    );
    // Compare against the real (unmasked) value of each list entry.
    const realValue =
      input &&
      (input.hasAttribute("data-real-value")
        ? input.getAttribute("data-real-value")
        : input.value);
    if (realValue && keysToDelete.has(realValue)) {
      item.remove();
      deleteCount++;
    }
  });

  closeModal(bulkDeleteApiKeyModal);

  if (deleteCount > 0) {
    showNotification(`成功删除了 ${deleteCount} 个匹配的密钥`, "success");
  } else {
    showNotification("列表中未找到您输入的任何密钥进行删除", "info");
  }
  bulkDeleteApiKeyInput.value = "";
}

/**
 * Handles the bulk addition of proxies from the modal input.
 */
function handleBulkAddProxies() {
  const proxyContainer = document.getElementById("PROXIES_container");
  if (!proxyBulkInput || !proxyContainer || !proxyModal) return;

  const bulkText = proxyBulkInput.value;
  const extractedProxies = bulkText.match(PROXY_REGEX) || [];

  const currentProxyInputs = proxyContainer.querySelectorAll(
    `.${ARRAY_INPUT_CLASS}`
  );
  const currentProxies = Array.from(currentProxyInputs)
    .map((input) => input.value)
    .filter((proxy) => proxy.trim() !== "");

  // Deduplicate existing + pasted proxies, then rebuild the list.
  const combinedProxies = new Set([...currentProxies, ...extractedProxies]);
  const uniqueProxies = Array.from(combinedProxies);

  proxyContainer.innerHTML = ""; // Clear existing items

  uniqueProxies.forEach((proxy) => {
    addArrayItemWithValue("PROXIES", proxy);
  });

  closeModal(proxyModal);
  showNotification(`添加/更新了 ${uniqueProxies.length} 个唯一代理`, "success");
}

/**
 * Handles the bulk deletion of proxies based on input from the modal.
+ */ +function handleBulkDeleteProxies() { + const proxyContainer = document.getElementById("PROXIES_container"); + if (!bulkDeleteProxyInput || !proxyContainer || !bulkDeleteProxyModal) return; + + const bulkText = bulkDeleteProxyInput.value; + if (!bulkText.trim()) { + showNotification("请粘贴需要删除的代理地址", "warning"); + return; + } + + const proxiesToDelete = new Set(bulkText.match(PROXY_REGEX) || []); + + if (proxiesToDelete.size === 0) { + showNotification("未在输入内容中提取到有效的代理地址格式", "warning"); + return; + } + + const proxyItems = proxyContainer.querySelectorAll(`.${ARRAY_ITEM_CLASS}`); + let deleteCount = 0; + + proxyItems.forEach((item) => { + const input = item.querySelector(`.${ARRAY_INPUT_CLASS}`); + if (input && proxiesToDelete.has(input.value)) { + item.remove(); + deleteCount++; + } + }); + + closeModal(bulkDeleteProxyModal); + + if (deleteCount > 0) { + showNotification(`成功删除了 ${deleteCount} 个匹配的代理`, "success"); + } else { + showNotification("列表中未找到您输入的任何代理进行删除", "info"); + } + bulkDeleteProxyInput.value = ""; +} + +/** + * Handles the bulk addition of Vertex API keys from the modal input. + */ +function handleBulkAddVertexApiKeys() { + const vertexApiKeyContainer = document.getElementById( + "VERTEX_API_KEYS_container" + ); + if ( + !vertexApiKeyBulkInput || + !vertexApiKeyContainer || + !vertexApiKeyModal + ) { + return; + } + + const bulkText = vertexApiKeyBulkInput.value; + const extractedKeys = bulkText.match(VERTEX_API_KEY_REGEX) || []; + + const currentKeyInputs = vertexApiKeyContainer.querySelectorAll( + `.${ARRAY_INPUT_CLASS}.${SENSITIVE_INPUT_CLASS}` + ); + let currentKeys = Array.from(currentKeyInputs) + .map((input) => { + return input.hasAttribute("data-real-value") + ? 
input.getAttribute("data-real-value") + : input.value; + }) + .filter((key) => key && key.trim() !== "" && key !== MASKED_VALUE); + + const combinedKeys = new Set([...currentKeys, ...extractedKeys]); + const uniqueKeys = Array.from(combinedKeys); + + vertexApiKeyContainer.innerHTML = ""; // Clear existing items + + uniqueKeys.forEach((key) => { + addArrayItemWithValue("VERTEX_API_KEYS", key); // VERTEX_API_KEYS are sensitive + }); + + // Ensure new sensitive inputs are masked + const newKeyInputs = vertexApiKeyContainer.querySelectorAll( + `.${ARRAY_INPUT_CLASS}.${SENSITIVE_INPUT_CLASS}` + ); + newKeyInputs.forEach((input) => { + if (configForm && typeof initializeSensitiveFields === "function") { + const focusoutEvent = new Event("focusout", { + bubbles: true, + cancelable: true, + }); + input.dispatchEvent(focusoutEvent); + } + }); + + closeModal(vertexApiKeyModal); + showNotification( + `添加/更新了 ${uniqueKeys.length} 个唯一 Vertex 密钥`, + "success" + ); + vertexApiKeyBulkInput.value = ""; +} + +/** + * Handles the bulk deletion of Vertex API keys based on input from the modal. 
+ */ +function handleBulkDeleteVertexApiKeys() { + const vertexApiKeyContainer = document.getElementById( + "VERTEX_API_KEYS_container" + ); + if ( + !bulkDeleteVertexApiKeyInput || + !vertexApiKeyContainer || + !bulkDeleteVertexApiKeyModal + ) { + return; + } + + const bulkText = bulkDeleteVertexApiKeyInput.value; + if (!bulkText.trim()) { + showNotification("请粘贴需要删除的 Vertex API 密钥", "warning"); + return; + } + + const keysToDelete = new Set(bulkText.match(VERTEX_API_KEY_REGEX) || []); + + if (keysToDelete.size === 0) { + showNotification( + "未在输入内容中提取到有效的 Vertex API 密钥格式", + "warning" + ); + return; + } + + const keyItems = vertexApiKeyContainer.querySelectorAll(`.${ARRAY_ITEM_CLASS}`); + let deleteCount = 0; + + keyItems.forEach((item) => { + const input = item.querySelector( + `.${ARRAY_INPUT_CLASS}.${SENSITIVE_INPUT_CLASS}` + ); + const realValue = + input && + (input.hasAttribute("data-real-value") + ? input.getAttribute("data-real-value") + : input.value); + if (realValue && keysToDelete.has(realValue)) { + item.remove(); + deleteCount++; + } + }); + + closeModal(bulkDeleteVertexApiKeyModal); + + if (deleteCount > 0) { + showNotification(`成功删除了 ${deleteCount} 个匹配的 Vertex 密钥`, "success"); + } else { + showNotification("列表中未找到您输入的任何 Vertex 密钥进行删除", "info"); + } + bulkDeleteVertexApiKeyInput.value = ""; +} + +/** + * Switches the active configuration tab. + * @param {string} tabId - The ID of the tab to switch to. 
+ */ +function switchTab(tabId) { + console.log(`Switching to tab: ${tabId}`); + + // 定义选中态和未选中态的样式 + const activeStyle = "background-color: #3b82f6 !important; color: #ffffff !important; border: 2px solid #2563eb !important; box-shadow: 0 4px 12px -2px rgba(59, 130, 246, 0.4), 0 2px 6px -1px rgba(59, 130, 246, 0.2) !important; transform: translateY(-2px) !important; font-weight: 600 !important;"; + const inactiveStyle = "background-color: #f8fafc !important; color: #64748b !important; border: 2px solid #e2e8f0 !important; box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1) !important; font-weight: 500 !important; transform: none !important;"; + + // 更新标签按钮状态 + const tabButtons = document.querySelectorAll(".tab-btn"); + console.log(`Found ${tabButtons.length} tab buttons`); + + tabButtons.forEach((button) => { + const buttonTabId = button.getAttribute("data-tab"); + if (buttonTabId === tabId) { + // 激活状态:直接设置内联样式 + button.classList.add("active"); + button.setAttribute("style", activeStyle); + console.log(`Applied active style to button: ${buttonTabId}`); + } else { + // 非激活状态:直接设置内联样式 + button.classList.remove("active"); + button.setAttribute("style", inactiveStyle); + console.log(`Applied inactive style to button: ${buttonTabId}`); + } + }); + + // 更新内容区域 + const sections = document.querySelectorAll(".config-section"); + sections.forEach((section) => { + if (section.id === `${tabId}-section`) { + section.classList.add("active"); + } else { + section.classList.remove("active"); + } + }); +} + +/** + * Toggles the visibility of configuration sections for different upload providers. + * @param {string} provider - The selected upload provider. 
+ */ +function toggleProviderConfig(provider) { + const providerConfigs = document.querySelectorAll(".provider-config"); + providerConfigs.forEach((config) => { + if (config.getAttribute("data-provider") === provider) { + config.classList.add("active"); + } else { + config.classList.remove("active"); + } + }); +} + +/** + * Creates and appends an input field for an array item. + * @param {string} key - The configuration key for the array. + * @param {string} value - The initial value for the input field. + * @param {boolean} isSensitive - Whether the input is for sensitive data. + * @param {string|null} modelId - Optional model ID for thinking models. + * @returns {HTMLInputElement} The created input element. + */ +function createArrayInput(key, value, isSensitive, modelId = null) { + const input = document.createElement("input"); + input.type = "text"; + input.name = `${key}[]`; // Used for form submission if not handled by JS + input.value = value; + let inputClasses = `${ARRAY_INPUT_CLASS} flex-grow px-3 py-2 border-none rounded-l-md focus:outline-none form-input-themed`; + if (isSensitive) { + inputClasses += ` ${SENSITIVE_INPUT_CLASS}`; + } + input.className = inputClasses; + if (modelId) { + input.setAttribute("data-model-id", modelId); + input.placeholder = "思考模型名称"; + } + return input; +} + +/** + * Creates a generate token button for allowed tokens. + * @returns {HTMLButtonElement} The created button element. + */ +function createGenerateTokenButton() { + const generateBtn = document.createElement("button"); + generateBtn.type = "button"; + generateBtn.className = + "generate-btn px-2 py-2 text-gray-500 hover:text-primary-600 focus:outline-none rounded-r-md bg-gray-100 hover:bg-gray-200 transition-colors"; + generateBtn.innerHTML = ''; + generateBtn.title = "生成随机令牌"; + // Event listener will be added via delegation in DOMContentLoaded + return generateBtn; +} + +/** + * Creates a remove button for an array item. 
+ * @returns {HTMLButtonElement} The created button element. + */ +function createRemoveButton() { + const removeBtn = document.createElement("button"); + removeBtn.type = "button"; + removeBtn.className = + "remove-btn text-gray-400 hover:text-red-500 focus:outline-none transition-colors duration-150"; + removeBtn.innerHTML = ''; + removeBtn.title = "删除"; + // Event listener will be added via delegation in DOMContentLoaded + return removeBtn; +} + +/** + * Adds a new item to an array configuration section (e.g., API_KEYS, ALLOWED_TOKENS). + * This function is typically called by a "+" button. + * @param {string} key - The configuration key for the array (e.g., 'API_KEYS'). + */ +function addArrayItem(key) { + const container = document.getElementById(`${key}_container`); + if (!container) return; + + const newItemValue = ""; // New items start empty + const modelId = addArrayItemWithValue(key, newItemValue); // This adds the DOM element + + if (key === "THINKING_MODELS" && modelId) { + createAndAppendBudgetMapItem(newItemValue, 0, modelId); // Default budget 0 + } +} + +/** + * Adds an array item with a specific value to the DOM. + * This is used both for initially populating the form and for adding new items. + * @param {string} key - The configuration key (e.g., 'API_KEYS', 'THINKING_MODELS'). + * @param {string} value - The value for the array item. + * @returns {string|null} The generated modelId if it's a thinking model, otherwise null. + */ +function addArrayItemWithValue(key, value) { + const container = document.getElementById(`${key}_container`); + if (!container) return null; + + const isThinkingModel = key === "THINKING_MODELS"; + const isAllowedToken = key === "ALLOWED_TOKENS"; + const isVertexApiKey = key === "VERTEX_API_KEYS"; // 新增判断 + const isSensitive = + key === "API_KEYS" || isAllowedToken || isVertexApiKey; // 更新敏感判断 + const modelId = isThinkingModel ? 
generateUUID() : null; + + const arrayItem = document.createElement("div"); + arrayItem.className = `${ARRAY_ITEM_CLASS} flex items-center mb-2 gap-2`; + if (isThinkingModel) { + arrayItem.setAttribute("data-model-id", modelId); + } + + const inputWrapper = document.createElement("div"); + inputWrapper.className = + "flex items-center flex-grow rounded-md focus-within:border-blue-500 focus-within:ring focus-within:ring-blue-500 focus-within:ring-opacity-50"; + // Apply light theme border directly via style + inputWrapper.style.border = "1px solid rgba(0, 0, 0, 0.12)"; + inputWrapper.style.backgroundColor = "transparent"; // Ensure wrapper is transparent + + const input = createArrayInput( + key, + value, + isSensitive, + isThinkingModel ? modelId : null + ); + inputWrapper.appendChild(input); + + if (isAllowedToken) { + const generateBtn = createGenerateTokenButton(); + inputWrapper.appendChild(generateBtn); + } else { + // Ensure right-side rounding if no button is present + input.classList.add("rounded-r-md"); + } + + const removeBtn = createRemoveButton(); + + arrayItem.appendChild(inputWrapper); + arrayItem.appendChild(removeBtn); + container.appendChild(arrayItem); + + // Initialize sensitive field if applicable + if (isSensitive && input.value) { + if (configForm && typeof initializeSensitiveFields === "function") { + const focusoutEvent = new Event("focusout", { + bubbles: true, + cancelable: true, + }); + input.dispatchEvent(focusoutEvent); + } + } + return isThinkingModel ? modelId : null; +} + +/** + * Creates and appends a DOM element for a thinking model's budget mapping. + * @param {string} mapKey - The model name (key for the map). + * @param {number|string} mapValue - The budget value. + * @param {string} modelId - The unique ID of the corresponding thinking model. 
+ */ +function createAndAppendBudgetMapItem(mapKey, mapValue, modelId) { + const container = document.getElementById("THINKING_BUDGET_MAP_container"); + if (!container) { + console.error( + "Cannot add budget item: THINKING_BUDGET_MAP_container not found!" + ); + return; + } + + // If container currently only has the placeholder, clear it + const placeholder = container.querySelector(".text-gray-500.italic"); + // Check if the only child is the placeholder before clearing + if ( + placeholder && + container.children.length === 1 && + container.firstChild === placeholder + ) { + container.innerHTML = ""; + } + + const mapItem = document.createElement("div"); + mapItem.className = `${MAP_ITEM_CLASS} flex items-center mb-2 gap-2`; + mapItem.setAttribute("data-model-id", modelId); + + const keyInput = document.createElement("input"); + keyInput.type = "text"; + keyInput.value = mapKey; + keyInput.placeholder = "模型名称 (自动关联)"; + keyInput.readOnly = true; + keyInput.className = `${MAP_KEY_INPUT_CLASS} flex-grow px-3 py-2 border border-gray-300 rounded-md focus:outline-none bg-gray-100 text-gray-500`; + keyInput.setAttribute("data-model-id", modelId); + + const valueInput = document.createElement("input"); + valueInput.type = "number"; + const intValue = parseInt(mapValue, 10); + valueInput.value = isNaN(intValue) ? 
0 : intValue; + valueInput.placeholder = "预算 (整数)"; + valueInput.className = `${MAP_VALUE_INPUT_CLASS} w-24 px-3 py-2 border border-gray-300 rounded-md focus:outline-none focus:border-primary-500 focus:ring focus:ring-primary-200 focus:ring-opacity-50`; + valueInput.min = -1; + valueInput.max = 32767; + valueInput.addEventListener("input", function () { + let val = this.value.replace(/[^0-9-]/g, ""); + if (val !== "") { + val = parseInt(val, 10); + if (val < -1) val = -1; + if (val > 32767) val = 32767; + } + this.value = val; // Corrected variable name + }); + + // Remove Button - Removed for budget map items + // const removeBtn = document.createElement('button'); + // removeBtn.type = 'button'; + // removeBtn.className = 'remove-btn text-gray-300 cursor-not-allowed focus:outline-none'; // Kept original class for reference + // removeBtn.innerHTML = ''; + // removeBtn.title = '请从上方模型列表删除'; + // removeBtn.disabled = true; + + mapItem.appendChild(keyInput); + mapItem.appendChild(valueInput); + // mapItem.appendChild(removeBtn); // Do not append the remove button + + container.appendChild(mapItem); +} + +/** + * Collects all data from the configuration form. + * @returns {object} An object containing all configuration data. 
+ */ +function collectFormData() { + const formData = {}; + + // 处理普通输入和 select + const inputsAndSelects = document.querySelectorAll( + 'input[type="text"], input[type="number"], input[type="password"], select, textarea' + ); + inputsAndSelects.forEach((element) => { + if ( + element.name && + !element.name.includes("[]") && + !element.closest(".array-container") && + !element.closest(`.${MAP_ITEM_CLASS}`) && + !element.closest(`.${SAFETY_SETTING_ITEM_CLASS}`) + ) { + if (element.type === "number") { + formData[element.name] = parseFloat(element.value); + } else if ( + element.classList.contains(SENSITIVE_INPUT_CLASS) && + element.hasAttribute("data-real-value") + ) { + formData[element.name] = element.getAttribute("data-real-value"); + } else { + formData[element.name] = element.value; + } + } + }); + + const checkboxes = document.querySelectorAll('input[type="checkbox"]'); + checkboxes.forEach((checkbox) => { + formData[checkbox.name] = checkbox.checked; + }); + + const arrayContainers = document.querySelectorAll(".array-container"); + arrayContainers.forEach((container) => { + const key = container.id.replace("_container", ""); + const arrayInputs = container.querySelectorAll(`.${ARRAY_INPUT_CLASS}`); + formData[key] = Array.from(arrayInputs) + .map((input) => { + if ( + input.classList.contains(SENSITIVE_INPUT_CLASS) && + input.hasAttribute("data-real-value") + ) { + return input.getAttribute("data-real-value"); + } + return input.value; + }) + .filter( + (value) => value && value.trim() !== "" && value !== MASKED_VALUE + ); // Ensure MASKED_VALUE is also filtered if not handled + }); + + const budgetMapContainer = document.getElementById( + "THINKING_BUDGET_MAP_container" + ); + if (budgetMapContainer) { + formData["THINKING_BUDGET_MAP"] = {}; + const mapItems = budgetMapContainer.querySelectorAll(`.${MAP_ITEM_CLASS}`); + mapItems.forEach((item) => { + const keyInput = item.querySelector(`.${MAP_KEY_INPUT_CLASS}`); + const valueInput = 
item.querySelector(`.${MAP_VALUE_INPUT_CLASS}`); + if (keyInput && valueInput && keyInput.value.trim() !== "") { + const budgetValue = parseInt(valueInput.value, 10); + formData["THINKING_BUDGET_MAP"][keyInput.value.trim()] = isNaN( + budgetValue + ) + ? 0 + : budgetValue; + } + }); + } + + if (safetySettingsContainer) { + formData["SAFETY_SETTINGS"] = []; + const settingItems = safetySettingsContainer.querySelectorAll( + `.${SAFETY_SETTING_ITEM_CLASS}` + ); + settingItems.forEach((item) => { + const categorySelect = item.querySelector(".safety-category-select"); + const thresholdSelect = item.querySelector(".safety-threshold-select"); + if ( + categorySelect && + thresholdSelect && + categorySelect.value && + thresholdSelect.value + ) { + formData["SAFETY_SETTINGS"].push({ + category: categorySelect.value, + threshold: thresholdSelect.value, + }); + } + }); + } + + // --- 新增:收集自动删除错误日志的配置 --- + const autoDeleteEnabledCheckbox = document.getElementById( + "AUTO_DELETE_ERROR_LOGS_ENABLED" + ); + if (autoDeleteEnabledCheckbox) { + formData["AUTO_DELETE_ERROR_LOGS_ENABLED"] = + autoDeleteEnabledCheckbox.checked; + } + + const autoDeleteDaysSelect = document.getElementById( + "AUTO_DELETE_ERROR_LOGS_DAYS" + ); + if (autoDeleteDaysSelect) { + // 如果复选框未选中,则不应提交天数,或者可以提交一个默认/无效值, + // 但后端应该只在 ENABLED 为 true 时才关心 DAYS。 + // 这里我们总是收集它,后端逻辑会处理。 + formData["AUTO_DELETE_ERROR_LOGS_DAYS"] = parseInt( + autoDeleteDaysSelect.value, + 10 + ); + } + // --- 结束:收集自动删除错误日志的配置 --- + + // --- 新增:收集自动删除请求日志的配置 --- + const autoDeleteRequestEnabledCheckbox = document.getElementById( + "AUTO_DELETE_REQUEST_LOGS_ENABLED" + ); + if (autoDeleteRequestEnabledCheckbox) { + formData["AUTO_DELETE_REQUEST_LOGS_ENABLED"] = + autoDeleteRequestEnabledCheckbox.checked; + } + + const autoDeleteRequestDaysSelect = document.getElementById( + "AUTO_DELETE_REQUEST_LOGS_DAYS" + ); + if (autoDeleteRequestDaysSelect) { + formData["AUTO_DELETE_REQUEST_LOGS_DAYS"] = parseInt( + autoDeleteRequestDaysSelect.value, 
+ 10 + ); + } + // --- 结束:收集自动删除请求日志的配置 --- + + // --- 新增:收集假流式配置 --- + const fakeStreamEnabledCheckbox = document.getElementById( + "FAKE_STREAM_ENABLED" + ); + if (fakeStreamEnabledCheckbox) { + formData["FAKE_STREAM_ENABLED"] = fakeStreamEnabledCheckbox.checked; + } + const fakeStreamIntervalInput = document.getElementById( + "FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS" + ); + if (fakeStreamIntervalInput) { + formData["FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS"] = parseInt( + fakeStreamIntervalInput.value, + 10 + ); + } + // --- 结束:收集假流式配置 --- + + return formData; +} + +/** + * Stops the scheduler task on the server. + */ +async function stopScheduler() { + try { + const response = await fetch("/api/scheduler/stop", { method: "POST" }); + if (!response.ok) { + console.warn(`停止定时任务失败: ${response.status}`); + } else { + console.log("定时任务已停止"); + } + } catch (error) { + console.error("调用停止定时任务API时出错:", error); + } +} + +/** + * Starts the scheduler task on the server. + */ +async function startScheduler() { + try { + const response = await fetch("/api/scheduler/start", { method: "POST" }); + if (!response.ok) { + console.warn(`启动定时任务失败: ${response.status}`); + } else { + console.log("定时任务已启动"); + } + } catch (error) { + console.error("调用启动定时任务API时出错:", error); + } +} + +/** + * Saves the current configuration to the server. + */ +async function saveConfig() { + try { + const formData = collectFormData(); + + showNotification("正在保存配置...", "info"); + + // 1. 停止定时任务 + await stopScheduler(); + + const response = await fetch("/api/config", { + method: "PUT", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(formData), + }); + + if (!response.ok) { + const errorData = await response.json(); + throw new Error( + errorData.detail || `HTTP error! status: ${response.status}` + ); + } + + const result = await response.json(); + + // 移除居中的 saveStatus 提示 + + showNotification("配置保存成功", "success"); + + // 3. 
启动新的定时任务 + await startScheduler(); + } catch (error) { + console.error("保存配置失败:", error); + // 保存失败时,也尝试重启定时任务,以防万一 + await startScheduler(); + // 移除居中的 saveStatus 提示 + + showNotification("保存配置失败: " + error.message, "error"); + } +} + +/** + * Initiates the configuration reset process by showing a confirmation modal. + * @param {Event} [event] - The click event, if triggered by a button. + */ +function resetConfig(event) { + // 阻止事件冒泡和默认行为 + if (event) { + event.preventDefault(); + event.stopPropagation(); + } + + console.log( + "resetConfig called. Event target:", + event ? event.target.id : "No event" + ); + + // Ensure modal is shown only if the event comes from the reset button + if ( + !event || + event.target.id === "resetBtn" || + (event.currentTarget && event.currentTarget.id === "resetBtn") + ) { + if (resetConfirmModal) { + openModal(resetConfirmModal); + } else { + console.error( + "Reset confirmation modal not found! Falling back to default confirm." + ); + if (confirm("确定要重置所有配置吗?这将恢复到默认值。")) { + executeReset(); + } + } + } +} + +/** + * Executes the actual configuration reset after confirmation. + */ +async function executeReset() { + try { + showNotification("正在重置配置...", "info"); + + // 1. 停止定时任务 + await stopScheduler(); + const response = await fetch("/api/config/reset", { method: "POST" }); + if (!response.ok) { + throw new Error(`HTTP error! 
status: ${response.status}`); + } + const config = await response.json(); + populateForm(config); + // Re-initialize masking for sensitive fields after reset + if (configForm && typeof initializeSensitiveFields === "function") { + const sensitiveFields = configForm.querySelectorAll( + `.${SENSITIVE_INPUT_CLASS}` + ); + sensitiveFields.forEach((field) => { + if (field.type === "password") { + if (field.value) field.setAttribute("data-real-value", field.value); + } else if ( + field.type === "text" || + field.tagName.toLowerCase() === "textarea" + ) { + const focusoutEvent = new Event("focusout", { + bubbles: true, + cancelable: true, + }); + field.dispatchEvent(focusoutEvent); + } + }); + } + showNotification("配置已重置为默认值", "success"); + + // 3. 启动新的定时任务 + await startScheduler(); + } catch (error) { + console.error("重置配置失败:", error); + showNotification("重置配置失败: " + error.message, "error"); + // 重置失败时,也尝试重启定时任务 + await startScheduler(); + } +} + +/** + * Displays a notification message to the user. + * @param {string} message - The message to display. + * @param {string} [type='info'] - The type of notification ('info', 'success', 'error', 'warning'). + */ +function showNotification(message, type = "info") { + const notification = document.getElementById("notification"); + notification.textContent = message; + + // 统一样式为黑色半透明,与 keys_status.js 保持一致 + notification.classList.remove("bg-danger-500"); + notification.classList.add("bg-black"); + notification.style.backgroundColor = "rgba(0,0,0,0.8)"; + notification.style.color = "#fff"; + + // 应用过渡效果 + notification.style.opacity = "1"; + notification.style.transform = "translate(-50%, 0)"; + + // 设置自动消失 + setTimeout(() => { + notification.style.opacity = "0"; + notification.style.transform = "translate(-50%, 10px)"; + }, 3000); +} + +/** + * Refreshes the current page. + * @param {HTMLButtonElement} [button] - The button that triggered the refresh (to show loading state). 
+ */ +function refreshPage(button) { + if (button) button.classList.add("loading"); + location.reload(); +} + +/** + * Scrolls the page to the top. + */ +function scrollToTop() { + window.scrollTo({ top: 0, behavior: "smooth" }); +} + +/** + * Scrolls the page to the bottom. + */ +function scrollToBottom() { + window.scrollTo({ top: document.body.scrollHeight, behavior: "smooth" }); +} + +/** + * Toggles the visibility of scroll-to-top/bottom buttons based on scroll position. + */ +function toggleScrollButtons() { + const scrollButtons = document.querySelector(".scroll-buttons"); + if (scrollButtons) { + scrollButtons.style.display = window.scrollY > 200 ? "flex" : "none"; + } +} + +/** + * Generates a random token string. + * @returns {string} A randomly generated token. + */ +function generateRandomToken() { + const characters = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_"; + const length = 48; + let result = "sk-"; + for (let i = 0; i < length; i++) { + result += characters.charAt(Math.floor(Math.random() * characters.length)); + } + return result; +} + +/** + * Adds a new safety setting item to the DOM. + * @param {string} [category=''] - The initial category for the setting. + * @param {string} [threshold=''] - The initial threshold for the setting. + */ +function addSafetySettingItem(category = "", threshold = "") { + const container = document.getElementById("SAFETY_SETTINGS_container"); + if (!container) { + console.error( + "Cannot add safety setting: SAFETY_SETTINGS_container not found!" 
+ ); + return; + } + + // 如果容器当前只有占位符,则清除它 + const placeholder = container.querySelector(".text-gray-500.italic"); + if ( + placeholder && + container.children.length === 1 && + container.firstChild === placeholder + ) { + container.innerHTML = ""; + } + + const harmCategories = [ + "HARM_CATEGORY_HARASSMENT", + "HARM_CATEGORY_HATE_SPEECH", + "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "HARM_CATEGORY_DANGEROUS_CONTENT", + "HARM_CATEGORY_CIVIC_INTEGRITY", // 根据需要添加或移除 + ]; + const harmThresholds = [ + "BLOCK_NONE", + "BLOCK_LOW_AND_ABOVE", + "BLOCK_MEDIUM_AND_ABOVE", + "BLOCK_ONLY_HIGH", + "OFF", // 根据 Google API 文档添加或移除 + ]; + + const settingItem = document.createElement("div"); + settingItem.className = `${SAFETY_SETTING_ITEM_CLASS} flex items-center mb-2 gap-2`; + + const categorySelect = document.createElement("select"); + categorySelect.className = + "safety-category-select flex-grow px-3 py-2 rounded-md focus:outline-none focus:border-primary-500 focus:ring focus:ring-primary-200 focus:ring-opacity-50 form-select-themed"; + harmCategories.forEach((cat) => { + const option = document.createElement("option"); + option.value = cat; + option.textContent = cat.replace("HARM_CATEGORY_", ""); + if (cat === category) option.selected = true; + categorySelect.appendChild(option); + }); + + const thresholdSelect = document.createElement("select"); + thresholdSelect.className = + "safety-threshold-select w-48 px-3 py-2 rounded-md focus:outline-none focus:border-primary-500 focus:ring focus:ring-primary-200 focus:ring-opacity-50 form-select-themed"; + harmThresholds.forEach((thr) => { + const option = document.createElement("option"); + option.value = thr; + option.textContent = thr.replace("BLOCK_", "").replace("_AND_ABOVE", "+"); + if (thr === threshold) option.selected = true; + thresholdSelect.appendChild(option); + }); + + const removeBtn = document.createElement("button"); + removeBtn.type = "button"; + removeBtn.className = + "remove-btn text-gray-400 hover:text-red-500 
focus:outline-none transition-colors duration-150"; + removeBtn.innerHTML = ''; + removeBtn.title = "删除此设置"; + // Event listener for removeBtn is now handled by event delegation in DOMContentLoaded + + settingItem.appendChild(categorySelect); + settingItem.appendChild(thresholdSelect); + settingItem.appendChild(removeBtn); + + container.appendChild(settingItem); +} + +// --- Model Helper Functions --- +async function fetchModels() { + if (cachedModelsList) { + return cachedModelsList; + } + try { + showNotification("正在从 /api/config/ui/models 加载模型列表...", "info"); + const response = await fetch("/api/config/ui/models"); + if (!response.ok) { + const errorData = await response.text(); + throw new Error(`HTTP error ${response.status}: ${errorData}`); + } + const responseData = await response.json(); // Changed variable name to responseData + // The backend returns an object like: { object: "list", data: [{id: "m1"}, {id: "m2"}], success: true } + if ( + responseData && + responseData.success && + Array.isArray(responseData.data) + ) { + cachedModelsList = responseData.data; // Use responseData.data + showNotification("模型列表加载成功", "success"); + return cachedModelsList; + } else { + console.error("Invalid model list format received:", responseData); + throw new Error("模型列表格式无效或请求未成功"); + } + } catch (error) { + console.error("加载模型列表失败:", error); + showNotification(`加载模型列表失败: ${error.message}`, "error"); + cachedModelsList = []; // Avoid repeated fetches on error for this session, or set to null to retry + return []; + } +} + +function renderModelsInModal() { + if (!modelHelperListContainer) return; + if (!cachedModelsList) { + modelHelperListContainer.innerHTML = + '

模型列表尚未加载。

'; + return; + } + + const searchTerm = modelHelperSearchInput.value.toLowerCase(); + const filteredModels = cachedModelsList.filter((model) => + model.id.toLowerCase().includes(searchTerm) + ); + + modelHelperListContainer.innerHTML = ""; // Clear previous items + + if (filteredModels.length === 0) { + modelHelperListContainer.innerHTML = + '

未找到匹配的模型。

'; + return; + } + + filteredModels.forEach((model) => { + const modelItemElement = document.createElement("button"); + modelItemElement.type = "button"; + modelItemElement.textContent = model.id; + modelItemElement.className = + "block w-full text-left px-4 py-2 rounded-md hover:bg-blue-100 focus:bg-blue-100 focus:outline-none transition-colors text-gray-700 hover:text-gray-800"; + // Add any other classes for styling, e.g., from existing modals or array items + + modelItemElement.addEventListener("click", () => + handleModelSelection(model.id) + ); + modelHelperListContainer.appendChild(modelItemElement); + }); +} + +async function openModelHelperModal() { + if (!currentModelHelperTarget) { + console.error("Model helper target not set."); + showNotification("无法打开模型助手:目标未设置", "error"); + return; + } + + await fetchModels(); // Ensure models are loaded + renderModelsInModal(); // Render them (handles empty/error cases internally) + + if (modelHelperTitleElement) { + if ( + currentModelHelperTarget.type === "input" && + currentModelHelperTarget.target + ) { + const label = document.querySelector( + `label[for="${currentModelHelperTarget.target.id}"]` + ); + modelHelperTitleElement.textContent = label + ? 
`为 "${label.textContent.trim()}" 选择模型` + : "选择模型"; + } else if (currentModelHelperTarget.type === "array") { + modelHelperTitleElement.textContent = `为 ${currentModelHelperTarget.targetKey} 添加模型`; + } else { + modelHelperTitleElement.textContent = "选择模型"; + } + } + if (modelHelperSearchInput) modelHelperSearchInput.value = ""; // Clear search on open + if (modelHelperModal) openModal(modelHelperModal); +} + +function handleModelSelection(selectedModelId) { + if (!currentModelHelperTarget) return; + + if ( + currentModelHelperTarget.type === "input" && + currentModelHelperTarget.target + ) { + const inputElement = currentModelHelperTarget.target; + inputElement.value = selectedModelId; + // If the input is a sensitive field, dispatch focusout to trigger masking behavior if needed + if (inputElement.classList.contains(SENSITIVE_INPUT_CLASS)) { + const event = new Event("focusout", { bubbles: true, cancelable: true }); + inputElement.dispatchEvent(event); + } + // Dispatch input event for any other listeners + inputElement.dispatchEvent(new Event("input", { bubbles: true })); + } else if ( + currentModelHelperTarget.type === "array" && + currentModelHelperTarget.targetKey + ) { + const modelId = addArrayItemWithValue( + currentModelHelperTarget.targetKey, + selectedModelId + ); + if (currentModelHelperTarget.targetKey === "THINKING_MODELS" && modelId) { + // Automatically add corresponding budget map item with default budget 0 + createAndAppendBudgetMapItem(selectedModelId, 0, modelId); + } + } + + if (modelHelperModal) closeModal(modelHelperModal); + currentModelHelperTarget = null; // Reset target +} + +// -- End Model Helper Functions -- diff --git a/app/static/js/error_logs.js b/app/static/js/error_logs.js new file mode 100644 index 0000000000000000000000000000000000000000..c1d84bb95c166a68fcc383951aeda9bdd94d1643 --- /dev/null +++ b/app/static/js/error_logs.js @@ -0,0 +1,1182 @@ +// 错误日志页面JavaScript (Updated for new structure, no Bootstrap) + +// 页面滚动功能 +function 
scrollToTop() { + window.scrollTo({ top: 0, behavior: "smooth" }); +} + +function scrollToBottom() { + window.scrollTo({ top: document.body.scrollHeight, behavior: "smooth" }); +} + +// API 调用辅助函数 +async function fetchAPI(url, options = {}) { + try { + const response = await fetch(url, options); + + // Handle cases where response might be empty but still ok (e.g., 204 No Content for DELETE) + if (response.status === 204) { + return null; // Indicate success with no content + } + + let responseData; + try { + responseData = await response.json(); + } catch (e) { + // Handle non-JSON responses if necessary, or assume error if JSON expected + if (!response.ok) { + // If response is not ok and not JSON, use statusText + throw new Error( + `HTTP error! status: ${response.status} - ${response.statusText}` + ); + } + // If response is ok but not JSON, maybe return raw text or handle differently + // For now, let's assume successful non-JSON is not expected or handled later + console.warn("Response was not JSON for URL:", url); + return await response.text(); // Or handle as needed + } + + if (!response.ok) { + // Prefer error message from API response body if available + const message = + responseData?.detail || + `HTTP error! status: ${response.status} - ${response.statusText}`; + throw new Error(message); + } + + return responseData; // Return parsed JSON data for successful responses + } catch (error) { + // Catch network errors or errors thrown from above + console.error( + "API Call Failed:", + error.message, + "URL:", + url, + "Options:", + options + ); + // Re-throw the error so the calling function knows the operation failed + throw error; + } +} + +// Refresh function removed as the buttons are gone. +// If refresh functionality is needed elsewhere, it can be triggered directly by calling loadErrorLogs(). 
+ +// 全局状态管理 +let errorLogState = { + currentPage: 1, + pageSize: 10, + logs: [], // 存储获取的日志 + sort: { + field: "id", // 默认按 ID 排序 + order: "desc", // 默认降序 + }, + search: { + key: "", + error: "", + errorCode: "", + startDate: "", + endDate: "", + }, +}; + +// DOM Elements Cache +let pageSizeSelector; +// let refreshBtn; // Removed, as the button is deleted +let tableBody; +let paginationElement; +let loadingIndicator; +let noDataMessage; +let errorMessage; +let logDetailModal; +let modalCloseBtns; // Collection of close buttons for the modal +let keySearchInput; +let errorSearchInput; +let errorCodeSearchInput; // Added error code input +let startDateInput; +let endDateInput; +let searchBtn; +let pageInput; +let goToPageBtn; +let selectAllCheckbox; // 新增:全选复选框 +let copySelectedKeysBtn; // 新增:复制选中按钮 +let deleteSelectedBtn; // 新增:批量删除按钮 +let sortByIdHeader; // 新增:ID 排序表头 +let sortIcon; // 新增:排序图标 +let selectedCountSpan; // 新增:选中计数显示 +let deleteConfirmModal; // 新增:删除确认模态框 +let closeDeleteConfirmModalBtn; // 新增:关闭删除模态框按钮 +let cancelDeleteBtn; // 新增:取消删除按钮 +let confirmDeleteBtn; // 新增:确认删除按钮 +let deleteConfirmMessage; // 新增:删除确认消息元素 +let idsToDeleteGlobally = []; // 新增:存储待删除的ID +let currentConfirmCallback = null; // 新增:存储当前的确认回调 +let deleteAllLogsBtn; // 新增:清空全部按钮 + +// Helper functions for initialization +function cacheDOMElements() { + pageSizeSelector = document.getElementById("pageSize"); + tableBody = document.getElementById("errorLogsTable"); + paginationElement = document.getElementById("pagination"); + loadingIndicator = document.getElementById("loadingIndicator"); + noDataMessage = document.getElementById("noDataMessage"); + errorMessage = document.getElementById("errorMessage"); + logDetailModal = document.getElementById("logDetailModal"); + modalCloseBtns = document.querySelectorAll( + "#closeLogDetailModalBtn, #closeModalFooterBtn" + ); + keySearchInput = document.getElementById("keySearch"); + errorSearchInput = document.getElementById("errorSearch"); + 
errorCodeSearchInput = document.getElementById("errorCodeSearch"); + startDateInput = document.getElementById("startDate"); + endDateInput = document.getElementById("endDate"); + searchBtn = document.getElementById("searchBtn"); + pageInput = document.getElementById("pageInput"); + goToPageBtn = document.getElementById("goToPageBtn"); + selectAllCheckbox = document.getElementById("selectAllCheckbox"); + copySelectedKeysBtn = document.getElementById("copySelectedKeysBtn"); + deleteSelectedBtn = document.getElementById("deleteSelectedBtn"); + sortByIdHeader = document.getElementById("sortById"); + if (sortByIdHeader) { + sortIcon = sortByIdHeader.querySelector("i"); + } + selectedCountSpan = document.getElementById("selectedCount"); + deleteConfirmModal = document.getElementById("deleteConfirmModal"); + closeDeleteConfirmModalBtn = document.getElementById( + "closeDeleteConfirmModalBtn" + ); + cancelDeleteBtn = document.getElementById("cancelDeleteBtn"); + confirmDeleteBtn = document.getElementById("confirmDeleteBtn"); + deleteConfirmMessage = document.getElementById("deleteConfirmMessage"); + deleteAllLogsBtn = document.getElementById("deleteAllLogsBtn"); // 缓存清空全部按钮 + } + + function initializePageSizeControls() { + if (pageSizeSelector) { + pageSizeSelector.value = errorLogState.pageSize; + pageSizeSelector.addEventListener("change", function () { + errorLogState.pageSize = parseInt(this.value); + errorLogState.currentPage = 1; // Reset to first page + loadErrorLogs(); + }); + } +} + +function initializeSearchControls() { + if (searchBtn) { + searchBtn.addEventListener("click", function () { + errorLogState.search.key = keySearchInput + ? keySearchInput.value.trim() + : ""; + errorLogState.search.error = errorSearchInput + ? errorSearchInput.value.trim() + : ""; + errorLogState.search.errorCode = errorCodeSearchInput + ? errorCodeSearchInput.value.trim() + : ""; + errorLogState.search.startDate = startDateInput + ? 
startDateInput.value + : ""; + errorLogState.search.endDate = endDateInput ? endDateInput.value : ""; + errorLogState.currentPage = 1; // Reset to first page on new search + loadErrorLogs(); + }); + } +} + +function initializeModalControls() { + // Log Detail Modal + if (logDetailModal && modalCloseBtns) { + modalCloseBtns.forEach((btn) => { + btn.addEventListener("click", closeLogDetailModal); + }); + logDetailModal.addEventListener("click", function (event) { + if (event.target === logDetailModal) { + closeLogDetailModal(); + } + }); + } + + // Delete Confirm Modal + if (closeDeleteConfirmModalBtn) { + closeDeleteConfirmModalBtn.addEventListener( + "click", + hideDeleteConfirmModal + ); + } + if (cancelDeleteBtn) { + cancelDeleteBtn.addEventListener("click", hideDeleteConfirmModal); + } + if (confirmDeleteBtn) { + confirmDeleteBtn.addEventListener("click", handleConfirmDelete); + } + if (deleteConfirmModal) { + deleteConfirmModal.addEventListener("click", function (event) { + if (event.target === deleteConfirmModal) { + hideDeleteConfirmModal(); + } + }); + } +} + +function initializePaginationJumpControls() { + if (goToPageBtn && pageInput) { + goToPageBtn.addEventListener("click", function () { + const targetPage = parseInt(pageInput.value); + if (!isNaN(targetPage) && targetPage >= 1) { + errorLogState.currentPage = targetPage; + loadErrorLogs(); + pageInput.value = ""; + } else { + showNotification("请输入有效的页码", "error", 2000); + pageInput.value = ""; + } + }); + pageInput.addEventListener("keypress", function (event) { + if (event.key === "Enter") { + goToPageBtn.click(); + } + }); + } +} + +function initializeActionControls() { + if (deleteSelectedBtn) { + deleteSelectedBtn.addEventListener("click", handleDeleteSelected); + } + if (sortByIdHeader) { + sortByIdHeader.addEventListener("click", handleSortById); + } + // Bulk selection listeners are closely related to actions + setupBulkSelectionListeners(); + + // 为 "清空全部" 按钮添加事件监听器 + if (deleteAllLogsBtn) { + 
deleteAllLogsBtn.addEventListener("click", function() { + const message = "您确定要清空所有错误日志吗?此操作不可恢复!"; + showDeleteConfirmModal(message, handleDeleteAllLogs); // 传入回调 + }); + } + } + + // 新增:处理 "清空全部" 逻辑的函数 + async function handleDeleteAllLogs() { + const url = "/api/logs/errors/all"; + const options = { + method: "DELETE", + }; + + try { + await fetchAPI(url, options); + showNotification("所有错误日志已成功清空", "success"); + if (selectAllCheckbox) selectAllCheckbox.checked = false; // 取消全选 + loadErrorLogs(); // 重新加载日志 + } catch (error) { + console.error("清空所有错误日志失败:", error); + showNotification(`清空失败: ${error.message}`, "error", 5000); + } + } + + // 页面加载完成后执行 +document.addEventListener("DOMContentLoaded", function () { + cacheDOMElements(); + initializePageSizeControls(); + initializeSearchControls(); + initializeModalControls(); + initializePaginationJumpControls(); + initializeActionControls(); + + // Initial load of error logs + loadErrorLogs(); + + // Add event listeners for copy buttons inside the modal and table + // This needs to be called after initial render and potentially after each render if content is dynamic + setupCopyButtons(); +}); + +// 新增:显示删除确认模态框 +function showDeleteConfirmModal(message, confirmCallback) { + if (deleteConfirmModal && deleteConfirmMessage) { + deleteConfirmMessage.textContent = message; + currentConfirmCallback = confirmCallback; // 存储回调 + deleteConfirmModal.classList.add("show"); + document.body.style.overflow = "hidden"; // Prevent body scrolling + } +} + +// 新增:隐藏删除确认模态框 +function hideDeleteConfirmModal() { + if (deleteConfirmModal) { + deleteConfirmModal.classList.remove("show"); + document.body.style.overflow = ""; // Restore body scrolling + idsToDeleteGlobally = []; // 清空待删除ID + currentConfirmCallback = null; // 清除回调 + } +} + +// 新增:处理确认删除按钮点击 +function handleConfirmDelete() { + if (typeof currentConfirmCallback === 'function') { + currentConfirmCallback(); // 调用存储的回调 + } + hideDeleteConfirmModal(); // 关闭模态框 +} + +// Fallback copy 
function using document.execCommand +function fallbackCopyTextToClipboard(text) { + const textArea = document.createElement("textarea"); + textArea.value = text; + + // Avoid scrolling to bottom + textArea.style.top = "0"; + textArea.style.left = "0"; + textArea.style.position = "fixed"; + + document.body.appendChild(textArea); + textArea.focus(); + textArea.select(); + + let successful = false; + try { + successful = document.execCommand("copy"); + } catch (err) { + console.error("Fallback copy failed:", err); + successful = false; + } + + document.body.removeChild(textArea); + return successful; +} + +// Helper function to handle feedback after copy attempt (both modern and fallback) +function handleCopyResult(buttonElement, success) { + const originalIcon = buttonElement.querySelector("i").className; // Store original icon class + const iconElement = buttonElement.querySelector("i"); + if (success) { + iconElement.className = "fas fa-check text-success-500"; // Use checkmark icon class + showNotification("已复制到剪贴板", "success", 2000); + } else { + iconElement.className = "fas fa-times text-danger-500"; // Use error icon class + showNotification("复制失败", "error", 3000); + } + setTimeout( + () => { + iconElement.className = originalIcon; + }, + success ? 2000 : 3000 + ); // Restore original icon class +} + +// 新的内部辅助函数,封装实际的复制操作和反馈 +function _performCopy(text, buttonElement) { + let copySuccess = false; + if (navigator.clipboard && window.isSecureContext) { + navigator.clipboard + .writeText(text) + .then(() => { + if (buttonElement) { + handleCopyResult(buttonElement, true); + } else { + showNotification("已复制到剪贴板", "success"); + } + }) + .catch((err) => { + console.error("Clipboard API failed, attempting fallback:", err); + copySuccess = fallbackCopyTextToClipboard(text); + if (buttonElement) { + handleCopyResult(buttonElement, copySuccess); + } else { + showNotification( + copySuccess ? "已复制到剪贴板" : "复制失败", + copySuccess ? 
"success" : "error" + ); + } + }); + } else { + console.warn( + "Clipboard API not available or context insecure. Using fallback copy method." + ); + copySuccess = fallbackCopyTextToClipboard(text); + if (buttonElement) { + handleCopyResult(buttonElement, copySuccess); + } else { + showNotification( + copySuccess ? "已复制到剪贴板" : "复制失败", + copySuccess ? "success" : "error" + ); + } + } +} + +// Function to set up copy button listeners (using modern API with fallback) - Updated to handle table copy buttons +function setupCopyButtons(containerSelector = "body") { + // Find buttons within the specified container (defaults to body) + const container = document.querySelector(containerSelector); + if (!container) return; + + const copyButtons = container.querySelectorAll(".copy-btn"); + copyButtons.forEach((button) => { + // Remove existing listener to prevent duplicates if called multiple times + button.removeEventListener("click", handleCopyButtonClick); + // Add the listener + button.addEventListener("click", handleCopyButtonClick); + }); +} + +// Extracted click handler logic for reusability and removing listeners +function handleCopyButtonClick() { + const button = this; // 'this' refers to the button clicked + const targetId = button.getAttribute("data-target"); + const textToCopyDirect = button.getAttribute("data-copy-text"); // For direct text copy (e.g., table key) + let textToCopy = ""; + + if (textToCopyDirect) { + textToCopy = textToCopyDirect; + } else if (targetId) { + const targetElement = document.getElementById(targetId); + if (targetElement) { + textToCopy = targetElement.textContent; + } else { + console.error("Target element not found:", targetId); + showNotification("复制出错:找不到目标元素", "error"); + return; // Exit if target element not found + } + } else { + console.error( + "No data-target or data-copy-text attribute found on button:", + button + ); + showNotification("复制出错:未指定复制内容", "error"); + return; // Exit if no source specified + } + + if (textToCopy) 
{ + _performCopy(textToCopy, button); // 使用新的辅助函数 + } else { + console.warn( + "No text found to copy for target:", + targetId || "direct text" + ); + showNotification("没有内容可复制", "warning"); + } +} // End of handleCopyButtonClick function + +// 新增:设置批量选择相关的事件监听器 +function setupBulkSelectionListeners() { + if (selectAllCheckbox) { + selectAllCheckbox.addEventListener("change", handleSelectAllChange); + } + + if (tableBody) { + // 使用事件委托处理行复选框的点击 + tableBody.addEventListener("change", handleRowCheckboxChange); + } + + if (copySelectedKeysBtn) { + copySelectedKeysBtn.addEventListener("click", handleCopySelectedKeys); + } + + // 新增:为批量删除按钮添加事件监听器 (如果尚未添加) + // 通常在 DOMContentLoaded 中添加一次即可 + // if (deleteSelectedBtn && !deleteSelectedBtn.hasListener) { + // deleteSelectedBtn.addEventListener('click', handleDeleteSelected); + // deleteSelectedBtn.hasListener = true; // 标记已添加 + // } +} + +// 新增:处理"全选"复选框变化的函数 +function handleSelectAllChange() { + const isChecked = selectAllCheckbox.checked; + const rowCheckboxes = tableBody.querySelectorAll(".row-checkbox"); + rowCheckboxes.forEach((checkbox) => { + checkbox.checked = isChecked; + }); + updateSelectedState(); +} + +// 新增:处理行复选框变化的函数 (事件委托) +function handleRowCheckboxChange(event) { + if (event.target.classList.contains("row-checkbox")) { + updateSelectedState(); + } +} + +// 新增:更新选中状态(计数、按钮状态、全选框状态) +function updateSelectedState() { + const rowCheckboxes = tableBody.querySelectorAll(".row-checkbox"); + const selectedCheckboxes = tableBody.querySelectorAll( + ".row-checkbox:checked" + ); + const selectedCount = selectedCheckboxes.length; + + // 移除了数字显示,不再更新selectedCountSpan + // 仍然更新复制按钮的禁用状态 + if (copySelectedKeysBtn) { + copySelectedKeysBtn.disabled = selectedCount === 0; + + // 可选:根据选中项数量更新按钮标题属性 + copySelectedKeysBtn.setAttribute("title", `复制${selectedCount}项选中密钥`); + } + // 新增:更新批量删除按钮的禁用状态 + if (deleteSelectedBtn) { + deleteSelectedBtn.disabled = selectedCount === 0; + deleteSelectedBtn.setAttribute("title", 
`删除${selectedCount}项选中日志`); + } + + // 更新"全选"复选框的状态 + if (selectAllCheckbox) { + if (rowCheckboxes.length > 0 && selectedCount === rowCheckboxes.length) { + selectAllCheckbox.checked = true; + selectAllCheckbox.indeterminate = false; + } else if (selectedCount > 0) { + selectAllCheckbox.checked = false; + selectAllCheckbox.indeterminate = true; // 部分选中状态 + } else { + selectAllCheckbox.checked = false; + selectAllCheckbox.indeterminate = false; + } + } +} + +// 新增:处理"复制选中密钥"按钮点击的函数 +function handleCopySelectedKeys() { + const selectedCheckboxes = tableBody.querySelectorAll( + ".row-checkbox:checked" + ); + const keysToCopy = []; + selectedCheckboxes.forEach((checkbox) => { + const key = checkbox.getAttribute("data-key"); + if (key) { + keysToCopy.push(key); + } + }); + + if (keysToCopy.length > 0) { + const textToCopy = keysToCopy.join("\n"); // 每行一个密钥 + _performCopy(textToCopy, copySelectedKeysBtn); // 使用新的辅助函数 + } else { + showNotification("没有选中的密钥可复制", "warning"); + } +} + +// 修改:处理批量删除按钮点击的函数 - 改为显示模态框 +function handleDeleteSelected() { + const selectedCheckboxes = tableBody.querySelectorAll( + ".row-checkbox:checked" + ); + const logIdsToDelete = []; + selectedCheckboxes.forEach((checkbox) => { + const logId = checkbox.getAttribute("data-log-id"); // 需要在渲染时添加 data-log-id + if (logId) { + logIdsToDelete.push(parseInt(logId)); + } + }); + + if (logIdsToDelete.length === 0) { + showNotification("没有选中的日志可删除", "warning"); + return; + } + + // 存储待删除ID并显示模态框 + idsToDeleteGlobally = logIdsToDelete; // 仍然需要设置,因为 performActualDelete 会用到 + const message = `确定要删除选中的 ${logIdsToDelete.length} 条日志吗?此操作不可恢复!`; + showDeleteConfirmModal(message, function() { // 传入匿名回调 + performActualDelete(idsToDeleteGlobally); + }); +} + +// 新增:执行实际的删除操作(提取自原 handleDeleteSelected 和 handleDeleteLogRow) +async function performActualDelete(logIds) { + if (!logIds || logIds.length === 0) return; + 
const isSingleDelete = logIds.length === 1; + const url = isSingleDelete + ? `/api/logs/errors/${logIds[0]}` + : "/api/logs/errors"; + const method = "DELETE"; + const body = isSingleDelete ? null : JSON.stringify({ ids: logIds }); + const headers = isSingleDelete ? {} : { "Content-Type": "application/json" }; + const options = { + method: method, + headers: headers, + body: body, // fetchAPI handles null body correctly + }; + + try { + // Use fetchAPI for the delete request + await fetchAPI(url, options); // fetchAPI returns null for 204 No Content + + // If fetchAPI doesn't throw, the request was successful + const successMessage = isSingleDelete + ? `成功删除该日志` + : `成功删除 ${logIds.length} 条日志`; + showNotification(successMessage, "success"); + // 取消全选 + if (selectAllCheckbox) selectAllCheckbox.checked = false; + // 重新加载当前页数据 + loadErrorLogs(); + } catch (error) { + console.error("批量删除错误日志失败:", error); + showNotification(`批量删除失败: ${error.message}`, "error", 5000); + } +} + +// 修改:处理单行删除按钮点击的函数 - 改为显示模态框 +function handleDeleteLogRow(logId) { + if (!logId) return; + + // 存储待删除ID并显示模态框 + idsToDeleteGlobally = [parseInt(logId)]; // 存储为数组 // 仍然需要设置,因为 performActualDelete 会用到 + // 使用通用确认消息,不显示具体ID + const message = `确定要删除这条日志吗?此操作不可恢复!`; + showDeleteConfirmModal(message, function() { // 传入匿名回调 + performActualDelete([parseInt(logId)]); // 确保传递的是数组 + }); +} + +// 新增:处理 ID 排序点击的函数 +function handleSortById() { + if (errorLogState.sort.field === "id") { + // 如果当前是按 ID 排序,切换顺序 + errorLogState.sort.order = + errorLogState.sort.order === "asc" ? 
"desc" : "asc"; + } else { + // 如果当前不是按 ID 排序,切换到按 ID 排序,默认为降序 + errorLogState.sort.field = "id"; + errorLogState.sort.order = "desc"; + } + // 更新图标 + updateSortIcon(); + // 重新加载第一页数据 + errorLogState.currentPage = 1; + loadErrorLogs(); +} + +// 新增:更新排序图标的函数 +function updateSortIcon() { + if (!sortIcon) return; + // 移除所有可能的排序类 + sortIcon.classList.remove( + "fa-sort", + "fa-sort-up", + "fa-sort-down", + "text-gray-400", + "text-primary-600" + ); + + if (errorLogState.sort.field === "id") { + sortIcon.classList.add( + errorLogState.sort.order === "asc" ? "fa-sort-up" : "fa-sort-down" + ); + sortIcon.classList.add("text-primary-600"); // 高亮显示 + } else { + // 如果不是按 ID 排序,显示默认图标 + sortIcon.classList.add("fa-sort", "text-gray-400"); + } +} + +// 加载错误日志数据 +async function loadErrorLogs() { + // 重置选择状态 + if (selectAllCheckbox) selectAllCheckbox.checked = false; + if (selectAllCheckbox) selectAllCheckbox.indeterminate = false; + updateSelectedState(); // 更新按钮状态和计数 + + showLoading(true); + showError(false); + showNoData(false); + + const offset = (errorLogState.currentPage - 1) * errorLogState.pageSize; + + try { + // Construct the API URL with search and sort parameters + let apiUrl = `/api/logs/errors?limit=${errorLogState.pageSize}&offset=${offset}`; + // 添加排序参数 + apiUrl += `&sort_by=${errorLogState.sort.field}&sort_order=${errorLogState.sort.order}`; + + // 添加搜索参数 + if (errorLogState.search.key) { + apiUrl += `&key_search=${encodeURIComponent(errorLogState.search.key)}`; + } + if (errorLogState.search.error) { + apiUrl += `&error_search=${encodeURIComponent( + errorLogState.search.error + )}`; + } + if (errorLogState.search.errorCode) { + // Add error code to API request + apiUrl += `&error_code_search=${encodeURIComponent( + errorLogState.search.errorCode + )}`; + } + if (errorLogState.search.startDate) { + apiUrl += `&start_date=${encodeURIComponent( + errorLogState.search.startDate + )}`; + } + if (errorLogState.search.endDate) { + apiUrl += 
`&end_date=${encodeURIComponent(errorLogState.search.endDate)}`; + } + + // Use fetchAPI to get logs + const data = await fetchAPI(apiUrl); + + // API 现在返回 { logs: [], total: count } + // fetchAPI already parsed JSON + if (data && Array.isArray(data.logs)) { + errorLogState.logs = data.logs; // Store the list data (contains error_code) + renderErrorLogs(errorLogState.logs); + updatePagination(errorLogState.logs.length, data.total || -1); // Use total from response + } else { + // Handle unexpected data format even after successful fetch + console.error("Unexpected API response format:", data); + throw new Error("无法识别的API响应格式"); + } + + showLoading(false); + + if (errorLogState.logs.length === 0) { + showNoData(true); + } + } catch (error) { + console.error("获取错误日志失败:", error); + showLoading(false); + showError(true, error.message); // Show specific error message + } +} + +// Helper function to create HTML for a single log row +function _createLogRowHtml(log, sequentialId) { + // Format date + let formattedTime = "N/A"; + try { + const requestTime = new Date(log.request_time); + if (!isNaN(requestTime)) { + formattedTime = requestTime.toLocaleString("zh-CN", { + year: "numeric", + month: "2-digit", + day: "2-digit", + hour: "2-digit", + minute: "2-digit", + second: "2-digit", + hour12: false, + }); + } + } catch (e) { + console.error("Error formatting date:", e); + } + + const errorCodeContent = log.error_code || "无"; + + const maskKey = (key) => { + if (!key || key.length < 8) return key || "无"; + return `${key.substring(0, 4)}...${key.substring(key.length - 4)}`; + }; + const maskedKey = maskKey(log.gemini_key); + const fullKey = log.gemini_key || ""; + + return ` + + + + ${sequentialId} + + ${maskedKey} + + + ${log.error_type || "未知"} + ${errorCodeContent} + ${log.model_name || "未知"} + ${formattedTime} + + + + + `; +} + +// 渲染错误日志表格 +function renderErrorLogs(logs) { + if (!tableBody) return; + tableBody.innerHTML = ""; // Clear previous entries + + // 
重置全选复选框状态(在清空表格后) + if (selectAllCheckbox) { + selectAllCheckbox.checked = false; + selectAllCheckbox.indeterminate = false; + } + + if (!logs || logs.length === 0) { + // Handled by showNoData + return; + } + + const startIndex = (errorLogState.currentPage - 1) * errorLogState.pageSize; + + logs.forEach((log, index) => { + const sequentialId = startIndex + index + 1; + const row = document.createElement("tr"); + row.innerHTML = _createLogRowHtml(log, sequentialId); + tableBody.appendChild(row); + }); + + // Add event listeners to new 'View Details' buttons + document.querySelectorAll(".btn-view-details").forEach((button) => { + button.addEventListener("click", function () { + const logId = parseInt(this.getAttribute("data-log-id")); + showLogDetails(logId); + }); + }); + + // 新增:为新渲染的删除按钮添加事件监听器 + document.querySelectorAll(".btn-delete-row").forEach((button) => { + button.addEventListener("click", function () { + const logId = this.getAttribute("data-log-id"); + handleDeleteLogRow(logId); + }); + }); + + // Re-initialize copy buttons specifically for the newly rendered table rows + setupCopyButtons("#errorLogsTable"); + // Update selected state after rendering + updateSelectedState(); +} + +// 显示错误日志详情 (从 API 获取) +async function showLogDetails(logId) { + if (!logDetailModal) return; + + // Show loading state in modal (optional) + // Clear previous content and show a spinner or message + document.getElementById("modalGeminiKey").textContent = "加载中..."; + document.getElementById("modalErrorType").textContent = "加载中..."; + document.getElementById("modalErrorLog").textContent = "加载中..."; + document.getElementById("modalRequestMsg").textContent = "加载中..."; + document.getElementById("modalModelName").textContent = "加载中..."; + document.getElementById("modalRequestTime").textContent = "加载中..."; + + logDetailModal.classList.add("show"); + document.body.style.overflow = "hidden"; // Prevent body scrolling + + try { + // Use fetchAPI to get log details + const logDetails = 
await fetchAPI(`/api/logs/errors/${logId}/details`); + + // fetchAPI handles response.ok check and JSON parsing + if (!logDetails) { + // Handle case where API returns success but no data (if possible) + throw new Error("未找到日志详情"); + } + + // Format date + let formattedTime = "N/A"; + try { + const requestTime = new Date(logDetails.request_time); + if (!isNaN(requestTime)) { + formattedTime = requestTime.toLocaleString("zh-CN", { + year: "numeric", + month: "2-digit", + day: "2-digit", + hour: "2-digit", + minute: "2-digit", + second: "2-digit", + hour12: false, + }); + } + } catch (e) { + console.error("Error formatting date:", e); + } + + // Format request message (handle potential JSON) + let formattedRequestMsg = "无"; + if (logDetails.request_msg) { + try { + if ( + typeof logDetails.request_msg === "object" && + logDetails.request_msg !== null + ) { + formattedRequestMsg = JSON.stringify(logDetails.request_msg, null, 2); + } else if (typeof logDetails.request_msg === "string") { + // Try parsing if it looks like JSON, otherwise display as string + const trimmedMsg = logDetails.request_msg.trim(); + if (trimmedMsg.startsWith("{") || trimmedMsg.startsWith("[")) { + formattedRequestMsg = JSON.stringify( + JSON.parse(logDetails.request_msg), + null, + 2 + ); + } else { + formattedRequestMsg = logDetails.request_msg; + } + } else { + formattedRequestMsg = String(logDetails.request_msg); + } + } catch (e) { + formattedRequestMsg = String(logDetails.request_msg); // Fallback + console.warn("Could not parse request_msg as JSON:", e); + } + } + + // Populate modal content with fetched details + document.getElementById("modalGeminiKey").textContent = + logDetails.gemini_key || "无"; + document.getElementById("modalErrorType").textContent = + logDetails.error_type || "未知"; + document.getElementById("modalErrorLog").textContent = + logDetails.error_log || "无"; // Full error log + document.getElementById("modalRequestMsg").textContent = + formattedRequestMsg; // Full 
request message + document.getElementById("modalModelName").textContent = + logDetails.model_name || "未知"; + document.getElementById("modalRequestTime").textContent = formattedTime; + + // Re-initialize copy buttons specifically for the modal after content is loaded + setupCopyButtons("#logDetailModal"); + } catch (error) { + console.error("获取日志详情失败:", error); + // Show error in modal + document.getElementById("modalGeminiKey").textContent = "错误"; + document.getElementById("modalErrorType").textContent = "错误"; + document.getElementById( + "modalErrorLog" + ).textContent = `加载失败: ${error.message}`; + document.getElementById("modalRequestMsg").textContent = "错误"; + document.getElementById("modalModelName").textContent = "错误"; + document.getElementById("modalRequestTime").textContent = "错误"; + // Optionally show a notification + showNotification(`加载日志详情失败: ${error.message}`, "error", 5000); + } +} + +// Close Log Detail Modal +function closeLogDetailModal() { + if (logDetailModal) { + logDetailModal.classList.remove("show"); + // Optional: Restore body scrolling + document.body.style.overflow = ""; + } +} + +// 更新分页控件 +function updatePagination(currentItemCount, totalItems) { + if (!paginationElement) return; + paginationElement.innerHTML = ""; // Clear existing pagination + + // Calculate total pages only if totalItems is known and valid + let totalPages = 1; + if (totalItems >= 0) { + totalPages = Math.max(1, Math.ceil(totalItems / errorLogState.pageSize)); + } else if ( + currentItemCount < errorLogState.pageSize && + errorLogState.currentPage === 1 + ) { + // If less items than page size fetched on page 1, assume it's the only page + totalPages = 1; + } else { + // If total is unknown and more items might exist, we can't build full pagination + // We can show Prev/Next based on current page and if items were returned + console.warn("Total item count unknown, pagination will be limited."); + // Basic Prev/Next for unknown total + addPaginationLink( + 
paginationElement, + "«", + errorLogState.currentPage > 1, + () => { + errorLogState.currentPage--; + loadErrorLogs(); + } + ); + addPaginationLink( + paginationElement, + errorLogState.currentPage.toString(), + true, + null, + true + ); // Current page number (non-clickable) + addPaginationLink( + paginationElement, + "»", + currentItemCount === errorLogState.pageSize, + () => { + errorLogState.currentPage++; + loadErrorLogs(); + } + ); // Next enabled if full page was returned + return; // Exit here for limited pagination + } + + const maxPagesToShow = 5; // Max number of page links to show + let startPage = Math.max( + 1, + errorLogState.currentPage - Math.floor(maxPagesToShow / 2) + ); + let endPage = Math.min(totalPages, startPage + maxPagesToShow - 1); + + // Adjust startPage if endPage reaches the limit first + if (endPage === totalPages) { + startPage = Math.max(1, endPage - maxPagesToShow + 1); + } + + // Previous Button + addPaginationLink( + paginationElement, + "«", + errorLogState.currentPage > 1, + () => { + errorLogState.currentPage--; + loadErrorLogs(); + } + ); + + // First Page Button + if (startPage > 1) { + addPaginationLink(paginationElement, "1", true, () => { + errorLogState.currentPage = 1; + loadErrorLogs(); + }); + if (startPage > 2) { + addPaginationLink(paginationElement, "...", false); // Ellipsis + } + } + + // Page Number Buttons + for (let i = startPage; i <= endPage; i++) { + addPaginationLink( + paginationElement, + i.toString(), + true, + () => { + errorLogState.currentPage = i; + loadErrorLogs(); + }, + i === errorLogState.currentPage + ); + } + + // Last Page Button + if (endPage < totalPages) { + if (endPage < totalPages - 1) { + addPaginationLink(paginationElement, "...", false); // Ellipsis + } + addPaginationLink(paginationElement, totalPages.toString(), true, () => { + errorLogState.currentPage = totalPages; + loadErrorLogs(); + }); + } + + // Next Button + addPaginationLink( + paginationElement, + "»", + 
errorLogState.currentPage < totalPages, + () => { + errorLogState.currentPage++; + loadErrorLogs(); + } + ); +} + +// Helper function to add pagination links +function addPaginationLink( + parentElement, + text, + enabled, + clickHandler, + isActive = false +) { + // const pageItem = document.createElement('li'); // We are not using
  • anymore + + const pageLink = document.createElement("a"); + + // Base Tailwind classes for layout, size, and transition. Colors/borders will come from CSS. + let baseClasses = + "px-3 py-1 rounded-md text-sm transition duration-150 ease-in-out"; // Common classes + + if (isActive) { + pageLink.className = `${baseClasses} active`; // Add 'active' class for CSS + } else if (enabled) { + pageLink.className = baseClasses; // Just base classes, CSS handles the rest + } else { + // Disabled link (e.g., '...' or unavailable prev/next) + pageLink.className = `${baseClasses} disabled`; // Add 'disabled' class for CSS + } + + pageLink.href = "#"; // Prevent page jump + pageLink.innerHTML = text; + + if (enabled && clickHandler) { + pageLink.addEventListener("click", function (e) { + e.preventDefault(); + clickHandler(); + }); + } else { + // Handles !enabled (includes isActive as clickHandler is null for it, and '...' which has no clickHandler) + pageLink.addEventListener("click", (e) => e.preventDefault()); + } + + parentElement.appendChild(pageLink); // Directly append to the