Spaces:
Running
Running
File size: 5,649 Bytes
3b13b0e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
import os
import tomli
from loguru import logger
from typing import Dict, Any, Optional
from dataclasses import dataclass
def get_version_from_file():
"""从project_version文件中读取版本号"""
try:
version_file = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
"project_version"
)
if os.path.isfile(version_file):
with open(version_file, "r", encoding="utf-8") as f:
return f.read().strip()
return "0.1.0" # 默认版本号
except Exception as e:
logger.error(f"读取版本号文件失败: {str(e)}")
return "0.1.0" # 默认版本号
@dataclass
class WebUIConfig:
"""WebUI配置类"""
# UI配置
ui: Dict[str, Any] = None
# 代理配置
proxy: Dict[str, str] = None
# 应用配置
app: Dict[str, Any] = None
# Azure配置
azure: Dict[str, str] = None
# 项目版本
project_version: str = get_version_from_file()
# 项目根目录
root_dir: str = None
# Gemini API Key
gemini_api_key: str = ""
# 每批处理的图片数量
vision_batch_size: int = 5
# 提示词
vision_prompt: str = """..."""
# Narrato API 配置
narrato_api_url: str = "http://127.0.0.1:8000/api/v1/video/analyze"
narrato_api_key: str = ""
narrato_batch_size: int = 10
narrato_vision_model: str = "gemini-1.5-flash"
narrato_llm_model: str = "qwen-plus"
def __post_init__(self):
"""初始化默认值"""
self.ui = self.ui or {}
self.proxy = self.proxy or {}
self.app = self.app or {}
self.azure = self.azure or {}
self.root_dir = self.root_dir or os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
def load_config(config_path: Optional[str] = None) -> WebUIConfig:
"""加载配置文件
Args:
config_path: 配置文件路径,如果为None则使用默认路径
Returns:
WebUIConfig: 配置对象
"""
try:
if config_path is None:
config_path = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
".streamlit",
"webui.toml"
)
# 如果配置文件不存在,使用示例配置
if not os.path.exists(config_path):
example_config = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
"config.example.toml"
)
if os.path.exists(example_config):
config_path = example_config
else:
logger.warning(f"配置文件不存在: {config_path}")
return WebUIConfig()
# 读取配置文件
with open(config_path, "rb") as f:
config_dict = tomli.load(f)
# 创建配置对象,使用从文件读取的版本号
config = WebUIConfig(
ui=config_dict.get("ui", {}),
proxy=config_dict.get("proxy", {}),
app=config_dict.get("app", {}),
azure=config_dict.get("azure", {}),
# 不再从配置文件中获取project_version
)
return config
except Exception as e:
logger.error(f"加载配置文件失败: {e}")
return WebUIConfig()
def save_config(config: WebUIConfig, config_path: Optional[str] = None) -> bool:
"""保存配置到文件
Args:
config: 配置对象
config_path: 配置文件路径,如果为None则使用默认路径
Returns:
bool: 是否保存成功
"""
try:
if config_path is None:
config_path = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
".streamlit",
"webui.toml"
)
# 确保目录存在
os.makedirs(os.path.dirname(config_path), exist_ok=True)
# 转换为字典,不再保存版本号到配置文件
config_dict = {
"ui": config.ui,
"proxy": config.proxy,
"app": config.app,
"azure": config.azure
# 不再保存project_version到配置文件
}
# 保存配置
with open(config_path, "w", encoding="utf-8") as f:
import tomli_w
tomli_w.dump(config_dict, f)
return True
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
return False
def get_config() -> WebUIConfig:
"""获取全局配置对象
Returns:
WebUIConfig: 配置对象
"""
if not hasattr(get_config, "_config"):
get_config._config = load_config()
return get_config._config
def update_config(config_dict: Dict[str, Any]) -> bool:
"""更新配置
Args:
config_dict: 配置字典
Returns:
bool: 是否更新成功
"""
try:
config = get_config()
# 更新配置
if "ui" in config_dict:
config.ui.update(config_dict["ui"])
if "proxy" in config_dict:
config.proxy.update(config_dict["proxy"])
if "app" in config_dict:
config.app.update(config_dict["app"])
if "azure" in config_dict:
config.azure.update(config_dict["azure"])
# 不再从配置字典更新project_version
# 保存配置
return save_config(config)
except Exception as e:
logger.error(f"更新配置失败: {e}")
return False
# 导出全局配置对象
config = get_config() |