File size: 4,966 Bytes
76b9762
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
通用工具函数模块
"""
import json
import re
import base64
import requests
from typing import Dict, Any, List, Optional, Tuple
from pathlib import Path
import logging

from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, VALID_IMAGE_RATIOS

# Logger used by the helpers in this module (named explicitly, not via __name__).
helper_logger = logging.getLogger("app.utils")

# Project root: three directory levels above this module's resolved location.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
# Location of the plain-text VERSION file at the project root.
VERSION_FILE_PATH = PROJECT_ROOT.joinpath("VERSION")


def extract_mime_type_and_data(base64_string: str) -> Tuple[Optional[str], str]:
    """
    Split a base64 payload into its MIME type and encoded data.

    Strings beginning with ``data:`` are matched against DATA_URL_PATTERN.
    On a match, capture group 1 is taken as the MIME type and group 2 as the
    encoded data, and the non-standard ``image/jpg`` is normalised to
    ``image/jpeg``. Any other input, including a ``data:`` string the pattern
    does not match, is returned unchanged as the data part with a ``None``
    MIME type.

    Args:
        base64_string: Raw base64 text or a ``data:`` URL.

    Returns:
        tuple: (mime_type or None, encoded_data)
    """
    matched = (
        re.match(DATA_URL_PATTERN, base64_string)
        if base64_string.startswith('data:')
        else None
    )
    if matched is None:
        # Not a recognised data URL: the whole string is the payload.
        return None, base64_string

    mime_type, encoded_data = matched.group(1), matched.group(2)
    if mime_type == "image/jpg":
        mime_type = "image/jpeg"
    return mime_type, encoded_data


def convert_image_to_base64(url: str, *, timeout: Optional[float] = 30.0) -> str:
    """
    Download an image over HTTP and return its bytes encoded as base64 text.

    Args:
        url: Image URL, fetched with an HTTP GET request.
        timeout: Seconds to wait for the server, forwarded to ``requests.get``.
            ``None`` waits indefinitely (the behaviour before this parameter
            existed).

    Returns:
        str: Base64-encoded image bytes, without any ``data:`` prefix.

    Raises:
        Exception: If the server answers with a status code other than 200.
        requests.RequestException: On connection failures or timeouts.
    """
    # Without a timeout, requests can block forever on an unresponsive host.
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise Exception(f"Failed to fetch image: {response.status_code}")
    return base64.b64encode(response.content).decode('utf-8')


def format_json_response(data: Dict[str, Any], indent: int = 2) -> str:
    """
    Serialise *data* as indented JSON text.

    Non-ASCII characters are written verbatim rather than as ``\\uXXXX``
    escapes.

    Args:
        data: JSON-serialisable mapping to render.
        indent: Number of spaces per indentation level.

    Returns:
        str: The rendered JSON document.
    """
    encoder = json.JSONEncoder(indent=indent, ensure_ascii=False)
    return encoder.encode(data)


def parse_prompt_parameters(prompt: str, default_ratio: str = "1:1") -> Tuple[str, int, str]:
    """
    Extract inline generation parameters from a prompt.

    Supported tags:
    - ``{n:COUNT}``, e.g. ``{n:2}``: generate 2 images (allowed range 1-4).
    - ``{ratio:W:H}``, e.g. ``{ratio:16:9}``: use a 16:9 aspect ratio.

    When a tag is found, every identical copy of it is removed from the
    prompt and the result is stripped of surrounding whitespace. A prompt
    with no tags is returned unchanged.

    Args:
        prompt: Raw prompt text, possibly containing the tags above.
        default_ratio: Aspect ratio used when no ``{ratio:...}`` tag is present.

    Returns:
        tuple: (cleaned prompt, image count, aspect ratio)

    Raises:
        ValueError: If ``n`` is outside 1-4, or the ratio is not one of
            VALID_IMAGE_RATIOS.
    """
    n = 1
    aspect_ratio = default_ratio

    # Braces are escaped so they are unambiguously literal instead of relying
    # on `re` treating an invalid `{...}` quantifier as plain text.
    n_match = re.search(r'\{n:(\d+)\}', prompt)
    if n_match:
        n = int(n_match.group(1))
        if n < 1 or n > 4:
            raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
        prompt = prompt.replace(n_match.group(0), '').strip()

    ratio_match = re.search(r'\{ratio:(\d+:\d+)\}', prompt)
    if ratio_match:
        aspect_ratio = ratio_match.group(1)
        if aspect_ratio not in VALID_IMAGE_RATIOS:
            raise ValueError(
                f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
            )
        prompt = prompt.replace(ratio_match.group(0), '').strip()

    return prompt, n, aspect_ratio


def extract_image_urls_from_markdown(text: str) -> List[str]:
    """
    Collect image URLs from Markdown text.

    Every match of IMAGE_URL_PATTERN is found with ``re.findall`` and the
    second capture group of each match is returned.

    Args:
        text: Markdown source to scan.

    Returns:
        List[str]: The captured URLs, in order of appearance.
    """
    # Assumes IMAGE_URL_PATTERN has at least two capture groups, so findall
    # yields one tuple per match; index 1 is taken as the URL - TODO confirm.
    return [groups[1] for groups in re.findall(IMAGE_URL_PATTERN, text)]


def is_valid_api_key(key: str) -> bool:
    """
    Heuristically check whether *key* looks like a supported API key.

    A key passes when it starts with the Gemini-style ``AIza`` prefix or the
    OpenAI-style ``sk-`` prefix and is at least 30 characters long. This is a
    format check only; no service is contacted.

    Args:
        key: Candidate API key.

    Returns:
        bool: True if the key has a recognised prefix and sufficient length.
    """
    has_known_prefix = key.startswith(('AIza', 'sk-'))
    return has_known_prefix and len(key) >= 30



def get_current_version(default_version: str = "0.0.0", version_file: Optional[Path] = None) -> str:
    """
    Read the application version string from a VERSION file.

    Args:
        default_version: Value returned when the file is missing, empty,
            unreadable, or not valid UTF-8.
        version_file: File to read. Defaults to ``VERSION_FILE_PATH``
            (``VERSION`` at the project root).

    Returns:
        str: The file contents with surrounding whitespace and any UTF-8
        byte-order mark removed, or ``default_version`` on failure.
    """
    version_file = VERSION_FILE_PATH if version_file is None else Path(version_file)
    try:
        # utf-8-sig decodes plain UTF-8 identically but also drops a leading
        # BOM, which str.strip() would otherwise leave in the version string.
        version = version_file.read_text(encoding='utf-8-sig').strip()
    except FileNotFoundError:
        helper_logger.warning(
            "VERSION file not found at '%s'. Using default version '%s'.",
            version_file, default_version,
        )
        return default_version
    except (OSError, UnicodeDecodeError) as e:
        # OSError covers permission/directory errors; UnicodeDecodeError is a
        # ValueError (not an IOError) and would otherwise escape uncaught.
        helper_logger.error(
            "Error reading VERSION file ('%s'): %s. Using default version '%s'.",
            version_file, e, default_version,
        )
        return default_version

    if not version:
        helper_logger.warning(
            "VERSION file ('%s') is empty. Using default version '%s'.",
            version_file, default_version,
        )
        return default_version
    return version