Spaces:
Running
Running
""" | |
应用程序配置模块 | |
""" | |
import datetime | |
import json | |
from typing import Any, Dict, List, Type | |
from pydantic import ValidationError, ValidationInfo, field_validator | |
from pydantic_settings import BaseSettings | |
from sqlalchemy import insert, select, update | |
def parse_comma_separated_string(v: Any) -> List[str]: | |
"""解析逗号分隔的字符串为字符串列表""" | |
# Handle None or empty values | |
if v is None or v == "": | |
return [] | |
if isinstance(v, list): | |
return [str(item).strip() for item in v if str(item).strip()] | |
if isinstance(v, str): | |
# Handle empty string or whitespace-only string | |
if not v.strip(): | |
return [] | |
try: | |
# Attempt to parse as JSON list first, in case it's provided as such | |
parsed = json.loads(v) | |
if isinstance(parsed, list): | |
return [str(item).strip() for item in parsed if str(item).strip()] | |
except json.JSONDecodeError: | |
pass # Not a JSON string, proceed to comma split | |
# Split by comma and filter out empty strings | |
return [token.strip() for token in v.split(',') if token.strip()] | |
# For any other type, try to convert to string and process | |
try: | |
str_val = str(v) | |
if not str_val or str_val.strip() == "": | |
return [] | |
return [token.strip() for token in str_val.split(',') if token.strip()] | |
except Exception: | |
return [] | |
from app.core.constants import ( | |
API_VERSION, | |
DEFAULT_CREATE_IMAGE_MODEL, | |
DEFAULT_FILTER_MODELS, | |
DEFAULT_MODEL, | |
DEFAULT_SAFETY_SETTINGS, | |
DEFAULT_STREAM_CHUNK_SIZE, | |
DEFAULT_STREAM_LONG_TEXT_THRESHOLD, | |
DEFAULT_STREAM_MAX_DELAY, | |
DEFAULT_STREAM_MIN_DELAY, | |
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD, | |
DEFAULT_TIMEOUT, | |
MAX_RETRIES, | |
) | |
from app.log.logger import Logger | |
class Settings(BaseSettings): | |
# 数据库配置 | |
DATABASE_TYPE: str = "sqlite" # sqlite 或 mysql | |
SQLITE_DATABASE: str = "default_db" | |
MYSQL_HOST: str = "" | |
MYSQL_PORT: int = 3306 | |
MYSQL_USER: str = "" | |
MYSQL_PASSWORD: str = "" | |
MYSQL_DATABASE: str = "" | |
MYSQL_SOCKET: str = "" | |
# 验证 MySQL 配置 | |
def validate_mysql_config(cls, v: Any, info: ValidationInfo) -> Any: | |
if info.data.get("DATABASE_TYPE") == "mysql": | |
if v is None or v == "": | |
raise ValueError( | |
"MySQL configuration is required when DATABASE_TYPE is 'mysql'" | |
) | |
return v | |
# API相关配置 | |
API_KEYS: str = "" | |
ALLOWED_TOKENS: str = "" | |
def validate_api_keys(cls, v: str) -> str: | |
if not v or v.strip() == "": | |
raise ValueError("API_KEYS cannot be empty. Please provide at least one API key.") | |
return v | |
def api_keys_list(self) -> List[str]: | |
"""将API_KEYS字符串转换为列表""" | |
return parse_comma_separated_string(self.API_KEYS) | |
def allowed_tokens_list(self) -> List[str]: | |
"""将ALLOWED_TOKENS字符串转换为列表""" | |
return parse_comma_separated_string(self.ALLOWED_TOKENS) | |
BASE_URL: str = f"https://generativelanguage.googleapis.com/{API_VERSION}" | |
AUTH_TOKEN: str = "" | |
MAX_FAILURES: int = 3 | |
TEST_MODEL: str = DEFAULT_MODEL | |
TIME_OUT: int = DEFAULT_TIMEOUT | |
MAX_RETRIES: int = MAX_RETRIES | |
PROXIES: List[str] = [] | |
PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: bool = True # 是否使用一致性哈希来选择代理 | |
VERTEX_API_KEYS: List[str] = [] | |
VERTEX_EXPRESS_BASE_URL: str = "https://aiplatform.googleapis.com/v1beta1/publishers/google" | |
# 智能路由配置 | |
URL_NORMALIZATION_ENABLED: bool = False # 是否启用智能路由映射功能 | |
# 模型相关配置 | |
SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"] | |
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"] | |
FILTERED_MODELS: List[str] = DEFAULT_FILTER_MODELS | |
TOOLS_CODE_EXECUTION_ENABLED: bool = False | |
def parse_boolean(cls, v: Any) -> bool: | |
if isinstance(v, str): | |
return v.lower() == "true" | |
return v | |
SHOW_SEARCH_LINK: bool = True | |
SHOW_THINKING_PROCESS: bool = True | |
THINKING_MODELS: List[str] = [] | |
THINKING_BUDGET_MAP: Dict[str, float] = {} | |
# TTS相关配置 | |
TTS_MODEL: str = "gemini-2.5-flash-preview-tts" | |
TTS_VOICE_NAME: str = "Zephyr" | |
TTS_SPEED: str = "normal" | |
# 图像生成相关配置 | |
PAID_KEY: str = "" | |
CREATE_IMAGE_MODEL: str = DEFAULT_CREATE_IMAGE_MODEL | |
UPLOAD_PROVIDER: str = "smms" | |
SMMS_SECRET_TOKEN: str = "" | |
PICGO_API_KEY: str = "" | |
CLOUDFLARE_IMGBED_URL: str = "" | |
CLOUDFLARE_IMGBED_AUTH_CODE: str = "" | |
# 流式输出优化器配置 | |
STREAM_OPTIMIZER_ENABLED: bool = False | |
STREAM_MIN_DELAY: float = DEFAULT_STREAM_MIN_DELAY | |
STREAM_MAX_DELAY: float = DEFAULT_STREAM_MAX_DELAY | |
STREAM_SHORT_TEXT_THRESHOLD: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD | |
STREAM_LONG_TEXT_THRESHOLD: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD | |
STREAM_CHUNK_SIZE: int = DEFAULT_STREAM_CHUNK_SIZE | |
# 假流式配置 (Fake Streaming Configuration) | |
FAKE_STREAM_ENABLED: bool = False # 是否启用假流式输出 | |
FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS: int = 5 # 假流式发送空数据的间隔时间(秒) | |
# 调度器配置 | |
CHECK_INTERVAL_HOURS: int = 1 # 默认检查间隔为1小时 | |
TIMEZONE: str = "Asia/Shanghai" # 默认时区 | |
# github | |
GITHUB_REPO_OWNER: str = "snailyp" | |
GITHUB_REPO_NAME: str = "gemini-balance" | |
# 日志配置 | |
LOG_LEVEL: str = "INFO" | |
AUTO_DELETE_ERROR_LOGS_ENABLED: bool = True | |
AUTO_DELETE_ERROR_LOGS_DAYS: int = 7 | |
AUTO_DELETE_REQUEST_LOGS_ENABLED: bool = False | |
AUTO_DELETE_REQUEST_LOGS_DAYS: int = 30 | |
SAFETY_SETTINGS: List[Dict[str, str]] = DEFAULT_SAFETY_SETTINGS | |
def __init__(self, **kwargs): | |
super().__init__(**kwargs) | |
# 设置默认AUTH_TOKEN(如果未提供) | |
if not self.AUTH_TOKEN and self.ALLOWED_TOKENS: | |
tokens_list = self.allowed_tokens_list | |
if tokens_list: | |
self.AUTH_TOKEN = tokens_list[0] | |
# 创建全局配置实例 | |
settings = Settings() | |
def get_settings() -> Settings: | |
"""获取配置实例""" | |
return settings | |
def _parse_db_value(key: str, db_value: str, target_type: Type) -> Any: | |
"""尝试将数据库字符串值解析为目标 Python 类型""" | |
from app.log.logger import get_config_logger | |
logger = get_config_logger() | |
try: | |
# 处理 List[str] | |
if target_type == List[str]: | |
try: | |
parsed = json.loads(db_value) | |
if isinstance(parsed, list): | |
return [str(item) for item in parsed] | |
except json.JSONDecodeError: | |
return [item.strip() for item in db_value.split(",") if item.strip()] | |
logger.warning( | |
f"Could not parse '{db_value}' as List[str] for key '{key}', falling back to comma split or empty list." | |
) | |
return [item.strip() for item in db_value.split(",") if item.strip()] | |
# 处理 Dict[str, float] | |
elif target_type == Dict[str, float]: | |
parsed_dict = {} | |
try: | |
parsed = json.loads(db_value) | |
if isinstance(parsed, dict): | |
parsed_dict = {str(k): float(v) for k, v in parsed.items()} | |
else: | |
logger.warning( | |
f"Parsed DB value for key '{key}' is not a dictionary type. Value: {db_value}" | |
) | |
except (json.JSONDecodeError, ValueError, TypeError) as e1: | |
if isinstance(e1, json.JSONDecodeError) and "'" in db_value: | |
logger.warning( | |
f"Failed initial JSON parse for key '{key}'. Attempting to replace single quotes. Error: {e1}" | |
) | |
try: | |
corrected_db_value = db_value.replace("'", '"') | |
parsed = json.loads(corrected_db_value) | |
if isinstance(parsed, dict): | |
parsed_dict = {str(k): float(v) for k, v in parsed.items()} | |
else: | |
logger.warning( | |
f"Parsed DB value (after quote replacement) for key '{key}' is not a dictionary type. Value: {corrected_db_value}" | |
) | |
except (json.JSONDecodeError, ValueError, TypeError) as e2: | |
logger.error( | |
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}' even after replacing quotes: {e2}. Returning empty dict." | |
) | |
else: | |
logger.error( | |
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}': {e1}. Returning empty dict." | |
) | |
return parsed_dict | |
# 处理 List[Dict[str, str]] | |
elif target_type == List[Dict[str, str]]: | |
try: | |
parsed = json.loads(db_value) | |
if isinstance(parsed, list): | |
# 验证列表中的每个元素是否为字典,并且键和值都是字符串 | |
valid = all( | |
isinstance(item, dict) | |
and all(isinstance(k, str) for k in item.keys()) | |
and all(isinstance(v, str) for v in item.values()) | |
for item in parsed | |
) | |
if valid: | |
return parsed | |
else: | |
logger.warning( | |
f"Invalid structure in List[Dict[str, str]] for key '{key}'. Value: {db_value}" | |
) | |
return [] | |
else: | |
logger.warning( | |
f"Parsed DB value for key '{key}' is not a list type. Value: {db_value}" | |
) | |
return [] | |
except json.JSONDecodeError: | |
logger.error( | |
f"Could not parse '{db_value}' as JSON for List[Dict[str, str]] for key '{key}'. Returning empty list." | |
) | |
return [] | |
except Exception as e: | |
logger.error( | |
f"Error parsing List[Dict[str, str]] for key '{key}': {e}. Value: {db_value}. Returning empty list." | |
) | |
return [] | |
# 处理 bool | |
elif target_type == bool: | |
return db_value.lower() in ("true", "1", "yes", "on") | |
# 处理 int | |
elif target_type == int: | |
return int(db_value) | |
# 处理 float | |
elif target_type == float: | |
return float(db_value) | |
# 默认为 str 或其他 pydantic 能直接处理的类型 | |
else: | |
return db_value | |
except (ValueError, TypeError, json.JSONDecodeError) as e: | |
logger.warning( | |
f"Failed to parse db_value '{db_value}' for key '{key}' as type {target_type}: {e}. Using original string value." | |
) | |
return db_value # 解析失败则返回原始字符串 | |
async def sync_initial_settings(): | |
""" | |
应用启动时同步配置: | |
1. 从数据库加载设置。 | |
2. 将数据库设置合并到内存 settings (数据库优先)。 | |
3. 将最终的内存 settings 同步回数据库。 | |
""" | |
from app.log.logger import get_config_logger | |
logger = get_config_logger() | |
# 延迟导入以避免循环依赖和确保数据库连接已初始化 | |
from app.database.connection import database | |
from app.database.models import Settings as SettingsModel | |
global settings | |
logger.info("Starting initial settings synchronization...") | |
if not database.is_connected: | |
try: | |
await database.connect() | |
logger.info("Database connection established for initial sync.") | |
except Exception as e: | |
logger.error( | |
f"Failed to connect to database for initial settings sync: {e}. Skipping sync." | |
) | |
return | |
try: | |
# 1. 从数据库加载设置 | |
db_settings_raw: List[Dict[str, Any]] = [] | |
try: | |
query = select(SettingsModel.key, SettingsModel.value) | |
results = await database.fetch_all(query) | |
db_settings_raw = [ | |
{"key": row["key"], "value": row["value"]} for row in results | |
] | |
logger.info(f"Fetched {len(db_settings_raw)} settings from database.") | |
except Exception as e: | |
logger.error( | |
f"Failed to fetch settings from database: {e}. Proceeding with environment/dotenv settings." | |
) | |
# 即使数据库读取失败,也要继续执行,确保基于 env/dotenv 的配置能同步到数据库 | |
db_settings_map: Dict[str, str] = { | |
s["key"]: s["value"] for s in db_settings_raw | |
} | |
# 2. 将数据库设置合并到内存 settings (数据库优先) | |
updated_in_memory = False | |
for key, db_value in db_settings_map.items(): | |
if key == "DATABASE_TYPE": | |
logger.debug( | |
f"Skipping update of '{key}' in memory from database. " | |
"This setting is controlled by environment/dotenv." | |
) | |
continue | |
if hasattr(settings, key): | |
target_type = Settings.__annotations__.get(key) | |
if target_type: | |
try: | |
parsed_db_value = _parse_db_value(key, db_value, target_type) | |
memory_value = getattr(settings, key) | |
# 比较解析后的值和内存中的值 | |
# 注意:对于列表等复杂类型,直接比较可能不够健壮,但这里简化处理 | |
if parsed_db_value != memory_value: | |
# 检查类型是否匹配,以防解析函数返回了不兼容的类型 | |
type_match = False | |
if target_type == List[str] and isinstance( | |
parsed_db_value, list | |
): | |
type_match = True | |
elif target_type == Dict[str, float] and isinstance( | |
parsed_db_value, dict | |
): | |
type_match = True | |
elif target_type not in ( | |
List[str], | |
Dict[str, float], | |
) and isinstance(parsed_db_value, target_type): | |
type_match = True | |
if type_match: | |
setattr(settings, key, parsed_db_value) | |
logger.debug( | |
f"Updated setting '{key}' in memory from database value ({target_type})." | |
) | |
updated_in_memory = True | |
else: | |
logger.warning( | |
f"Parsed DB value type mismatch for key '{key}'. Expected {target_type}, got {type(parsed_db_value)}. Skipping update." | |
) | |
except Exception as e: | |
logger.error( | |
f"Error processing database setting for key '{key}': {e}" | |
) | |
else: | |
logger.warning( | |
f"Database setting '{key}' not found in Settings model definition. Ignoring." | |
) | |
# 如果内存中有更新,重新验证 Pydantic 模型(可选但推荐) | |
if updated_in_memory: | |
try: | |
# 重新加载以确保类型转换和验证 | |
settings = Settings(**settings.model_dump()) | |
logger.info( | |
"Settings object re-validated after merging database values." | |
) | |
except ValidationError as e: | |
logger.error( | |
f"Validation error after merging database settings: {e}. Settings might be inconsistent." | |
) | |
# 3. 将最终的内存 settings 同步回数据库 | |
final_memory_settings = settings.model_dump() | |
settings_to_update: List[Dict[str, Any]] = [] | |
settings_to_insert: List[Dict[str, Any]] = [] | |
now = datetime.datetime.now(datetime.timezone.utc) | |
existing_db_keys = set(db_settings_map.keys()) | |
for key, value in final_memory_settings.items(): | |
if key == "DATABASE_TYPE": | |
logger.debug( | |
f"Skipping synchronization of '{key}' to database. " | |
"This setting is controlled by environment/dotenv." | |
) | |
continue | |
# 序列化值为字符串或 JSON 字符串 | |
if isinstance(value, (list, dict)): | |
db_value = json.dumps( | |
value, ensure_ascii=False | |
) | |
elif isinstance(value, bool): | |
db_value = str(value).lower() | |
elif value is None: | |
db_value = "" | |
else: | |
db_value = str(value) | |
data = { | |
"key": key, | |
"value": db_value, | |
"description": f"{key} configuration setting", | |
"updated_at": now, | |
} | |
if key in existing_db_keys: | |
# 仅当值与数据库中的不同时才更新 | |
if db_settings_map[key] != db_value: | |
settings_to_update.append(data) | |
else: | |
# 如果键不在数据库中,则插入 | |
data["created_at"] = now | |
settings_to_insert.append(data) | |
# 在事务中执行批量插入和更新 | |
if settings_to_insert or settings_to_update: | |
try: | |
async with database.transaction(): | |
if settings_to_insert: | |
# 获取现有描述以避免覆盖 | |
query_existing = select( | |
SettingsModel.key, SettingsModel.description | |
).where( | |
SettingsModel.key.in_( | |
[s["key"] for s in settings_to_insert] | |
) | |
) | |
existing_desc = { | |
row["key"]: row["description"] | |
for row in await database.fetch_all(query_existing) | |
} | |
for item in settings_to_insert: | |
item["description"] = existing_desc.get( | |
item["key"], item["description"] | |
) | |
query_insert = insert(SettingsModel).values(settings_to_insert) | |
await database.execute(query=query_insert) | |
logger.info( | |
f"Synced (inserted) {len(settings_to_insert)} settings to database." | |
) | |
if settings_to_update: | |
# 获取现有描述以避免覆盖 | |
query_existing = select( | |
SettingsModel.key, SettingsModel.description | |
).where( | |
SettingsModel.key.in_( | |
[s["key"] for s in settings_to_update] | |
) | |
) | |
existing_desc = { | |
row["key"]: row["description"] | |
for row in await database.fetch_all(query_existing) | |
} | |
for setting_data in settings_to_update: | |
setting_data["description"] = existing_desc.get( | |
setting_data["key"], setting_data["description"] | |
) | |
query_update = ( | |
update(SettingsModel) | |
.where(SettingsModel.key == setting_data["key"]) | |
.values( | |
value=setting_data["value"], | |
description=setting_data["description"], | |
updated_at=setting_data["updated_at"], | |
) | |
) | |
await database.execute(query=query_update) | |
logger.info( | |
f"Synced (updated) {len(settings_to_update)} settings to database." | |
) | |
except Exception as e: | |
logger.error( | |
f"Failed to sync settings to database during startup: {str(e)}" | |
) | |
else: | |
logger.info( | |
"No setting changes detected between memory and database during initial sync." | |
) | |
# 刷新日志等级 | |
Logger.update_log_levels(final_memory_settings.get("LOG_LEVEL")) | |
except Exception as e: | |
logger.error(f"An unexpected error occurred during initial settings sync: {e}") | |
finally: | |
if database.is_connected: | |
try: | |
pass | |
except Exception as e: | |
logger.error(f"Error disconnecting database after initial sync: {e}") | |
logger.info("Initial settings synchronization finished.") | |