File size: 6,279 Bytes
76b9762
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import base64
import time
import uuid

from google import genai
from google.genai import types

from app.config.config import settings
from app.core.constants import VALID_IMAGE_RATIOS
from app.domain.openai_models import ImageGenerationRequest
from app.log.logger import get_image_create_logger
from app.utils.uploader import ImageUploaderFactory

logger = get_image_create_logger()


class ImageCreateService:
    """Generate images with the configured GenAI image model and package the
    results as an OpenAI-style images response.

    Images are returned either inline as base64 (``response_format ==
    "b64_json"``) or uploaded to the configured image host and returned as URLs.
    """

    # OpenAI-style ``size`` values accepted by ``generate_images`` and the
    # aspect ratio each one maps to.
    _SIZE_TO_ASPECT_RATIO = {
        "1024x1024": "1:1",
        "1792x1024": "16:9",
        "1024x1792": "9:16",
        # Legacy alias: earlier versions only accepted this misspelling of
        # 1024x1792 for portrait output; kept so existing callers keep working.
        "1027x1792": "9:16",
    }

    def __init__(self, aspect_ratio="1:1"):
        # Image model name, read from application settings.
        self.image_model = settings.CREATE_IMAGE_MODEL
        # Fallback ratio used by parse_prompt_parameters when the prompt has no
        # {ratio:...} directive and no explicit default_ratio is passed.
        self.aspect_ratio = aspect_ratio

    def parse_prompt_parameters(self, prompt: str, default_ratio=None) -> tuple:
        """Extract inline generation directives from ``prompt``.

        Supported directives:
        - ``{n:COUNT}``  e.g. ``{n:2}`` generates 2 images (1-4 allowed)
        - ``{ratio:W:H}`` e.g. ``{ratio:16:9}`` uses a 16:9 aspect ratio

        Each directive that is found is removed from the prompt, and the
        remaining prompt is stripped of surrounding whitespace.

        Args:
            prompt: Raw user prompt.
            default_ratio: Ratio returned when the prompt contains no
                ``{ratio:...}`` directive. Defaults to ``self.aspect_ratio``.

        Returns:
            ``(cleaned_prompt, n, aspect_ratio)`` where ``n`` is 1 when the
            prompt does not specify it.

        Raises:
            ValueError: if ``n`` is outside 1-4, or the ratio is not listed in
                ``VALID_IMAGE_RATIOS``.
        """
        import re

        n = 1
        aspect_ratio = self.aspect_ratio if default_ratio is None else default_ratio

        # {n:COUNT} directive
        n_match = re.search(r"\{n:(\d+)\}", prompt)
        if n_match:
            n = int(n_match.group(1))
            if not 1 <= n <= 4:
                raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
            prompt = prompt.replace(n_match.group(0), "").strip()

        # {ratio:W:H} directive
        ratio_match = re.search(r"\{ratio:(\d+:\d+)\}", prompt)
        if ratio_match:
            aspect_ratio = ratio_match.group(1)
            if aspect_ratio not in VALID_IMAGE_RATIOS:
                raise ValueError(
                    f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
                )
            prompt = prompt.replace(ratio_match.group(0), "").strip()

        return prompt, n, aspect_ratio

    def _create_uploader(self):
        """Build the image uploader selected by ``settings.UPLOAD_PROVIDER``.

        Raises:
            ValueError: if the provider is not smms, picgo or cloudflare_imgbed.
        """
        provider = settings.UPLOAD_PROVIDER
        if provider == "smms":
            return ImageUploaderFactory.create(
                provider=provider,
                api_key=settings.SMMS_SECRET_TOKEN,
            )
        if provider == "picgo":
            return ImageUploaderFactory.create(
                provider=provider,
                api_key=settings.PICGO_API_KEY,
            )
        if provider == "cloudflare_imgbed":
            return ImageUploaderFactory.create(
                provider=provider,
                base_url=settings.CLOUDFLARE_IMGBED_URL,
                auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
            )
        raise ValueError(f"Unsupported upload provider: {provider}")

    def generate_images(self, request: ImageGenerationRequest):
        """Generate images for ``request`` and return an OpenAI-style payload.

        The aspect ratio comes from ``request.size``; ``{n:...}`` / ``{ratio:...}``
        directives inside the prompt override ``request.n`` (only when > 1) and
        the size-derived ratio. ``request.prompt`` is replaced in place by the
        prompt with the directives removed.

        Returns:
            ``{"created": <unix seconds>, "data": [...]}`` where each entry holds
            either ``b64_json`` or ``url`` plus ``revised_prompt``.

        Raises:
            ValueError: on an unsupported size, invalid prompt directives, or an
                unsupported upload provider.
            Exception: when the model returns no images.
        """
        aspect_ratio = self._SIZE_TO_ASPECT_RATIO.get(request.size)
        if aspect_ratio is None:
            raise ValueError(
                f"Invalid size: {request.size}. Supported sizes are 1024x1024, 1792x1024, and 1024x1792."
            )

        # Prompt directives take precedence over the size-derived ratio. The
        # ratio is kept local so concurrent requests on a shared service
        # instance cannot overwrite each other's value.
        cleaned_prompt, prompt_n, aspect_ratio = self.parse_prompt_parameters(
            request.prompt, default_ratio=aspect_ratio
        )
        request.prompt = cleaned_prompt

        # A count given in the prompt overrides the request's n.
        if prompt_n > 1:
            request.n = prompt_n

        client = genai.Client(api_key=settings.PAID_KEY)
        response = client.models.generate_images(
            model=self.image_model,
            prompt=request.prompt,
            config=types.GenerateImagesConfig(
                number_of_images=request.n,
                output_mime_type="image/png",
                aspect_ratio=aspect_ratio,
                safety_filter_level="BLOCK_LOW_AND_ABOVE",
                person_generation="ALLOW_ADULT",
            ),
        )

        if not response.generated_images:
            raise Exception("I can't generate these images")

        use_base64 = request.response_format == "b64_json"
        # The uploader depends only on settings, so build it once per request.
        image_uploader = None if use_base64 else self._create_uploader()
        current_date = time.strftime("%Y/%m/%d")

        images_data = []
        for generated_image in response.generated_images:
            image_bytes = generated_image.image.image_bytes
            if use_base64:
                images_data.append(
                    {
                        "b64_json": base64.b64encode(image_bytes).decode("utf-8"),
                        "revised_prompt": request.prompt,
                    }
                )
            else:
                filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
                upload_response = image_uploader.upload(image_bytes, filename)
                images_data.append(
                    {
                        "url": f"{upload_response.data.url}",
                        "revised_prompt": request.prompt,
                    }
                )

        return {
            "created": int(time.time()),
            "data": images_data,
        }

    def generate_images_chat(self, request: ImageGenerationRequest) -> str:
        """Generate images and render them as Markdown image links, one per line.

        URL results are linked directly; base64 results are embedded as
        ``data:image/png;base64,...`` URLs. Returns an empty string when the
        response carries no images.
        """
        images = self.generate_images(request)["data"]
        markdown_images = []
        for index, image in enumerate(images, start=1):
            if "url" in image:
                source = image["url"]
            else:
                # Base64 payload: embed it as a data URL.
                source = f"data:image/png;base64,{image['b64_json']}"
            markdown_images.append(f"![Generated Image {index}]({source})")
        return "\n".join(markdown_images)