File size: 5,661 Bytes
b5e7375
 
 
 
 
e96f98e
 
 
 
 
b5e7375
e96f98e
b5e7375
 
 
 
 
 
e96f98e
b5e7375
 
e96f98e
b5e7375
 
e96f98e
 
b5e7375
 
e96f98e
b5e7375
e96f98e
 
 
 
 
 
b5e7375
 
e96f98e
b5e7375
e96f98e
 
 
 
 
 
 
b5e7375
 
e96f98e
 
b5e7375
 
e96f98e
 
 
 
 
 
 
b5e7375
 
e96f98e
b5e7375
 
e96f98e
b5e7375
e96f98e
b5e7375
 
 
e96f98e
b5e7375
 
e96f98e
b5e7375
 
e96f98e
b5e7375
 
e96f98e
 
b5e7375
e96f98e
b5e7375
e96f98e
 
 
 
 
 
b5e7375
 
e96f98e
b5e7375
e96f98e
b5e7375
e96f98e
b5e7375
 
 
 
e96f98e
b5e7375
 
e96f98e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#
# SPDX-FileCopyrightText: Hadad <[email protected]>
# SPDX-License-Identifier: Apache-2.0
#

import asyncio  # Import asyncio for asynchronous programming and managing event loops
import httpx  # Import httpx for async HTTP requests with HTTP/1.1 and HTTP/2 support
import aiohttp  # Import aiohttp for alternative async HTTP client capabilities
from urllib.parse import quote  # Import quote to safely encode URL path components
from typing import Optional  # Import Optional for type hinting parameters that may be None
from src.utils.ip_generator import generate_ip  # Import custom utility to generate random IP addresses for request headers
from src.utils.tools import initialize_tools  # Import utility to initialize and get tool endpoints

# Define a class named ImageGeneration to encapsulate functionalities related to generating image content
class ImageGeneration:
    """
    Asynchronous client for an external text-to-image generation service.

    The service endpoint comes from ``initialize_tools()``. The text prompt is
    URL-encoded and appended to that endpoint, and the generation options are
    sent as query parameters. The method returns a URL string rather than
    image bytes.

    Attributes:
        FORMATS (dict): Maps image format names to (width, height) tuples.

    Methods:
        create_image: Async method that requests an image for a text prompt and
                      returns the resulting URL. Each attempt tries httpx first
                      and then aiohttp, and attempts are repeated until one
                      succeeds (or an optional attempt limit is reached).
    """

    # Supported image formats with their dimensions (width, height)
    FORMATS = {
        "default": (1024, 1024),
        "square": (1024, 1024),
        "landscape": (1024, 768),
        "landscape_large": (1440, 1024),
        "portrait": (768, 1024),
        "portrait_large": (1024, 1440),
    }

    @staticmethod
    def _build_params(
        width: int,
        height: int,
        model: Optional[str],
        seed: Optional[int],
        nologo: bool,
        private: bool,
        enhance: bool,
    ) -> dict:
        """
        Build the query-string parameters for one generation request.

        Boolean flags are serialized as the strings "true"/"false". ``model``
        and ``seed`` are included only when they are not None. aiohttp (via
        yarl) raises TypeError for None query values, and httpx would send an
        empty ``model=``, so an unset value must be omitted entirely.

        Returns:
            dict: Parameters in the order width, height, [model], nologo,
                  private, enhance, [seed].
        """
        params = {"width": width, "height": height}
        if model is not None:
            params["model"] = model
        params["nologo"] = "true" if nologo else "false"
        params["private"] = "true" if private else "false"
        params["enhance"] = "true" if enhance else "false"
        if seed is not None:
            params["seed"] = seed
        return params

    @staticmethod
    async def _fetch_with_httpx(client, url: str, params: dict, headers: dict) -> Optional[str]:
        """
        Make one GET attempt with httpx.

        Returns the response URL as a string on HTTP 200. Returns None on any
        other status or on an httpx.HTTPError, so the caller can fall back.
        """
        try:
            resp = await client.get(url, params=params, headers=headers)
        except httpx.HTTPError:
            # Transport/protocol failure: signal the caller to try aiohttp.
            return None
        return str(resp.url) if resp.status_code == 200 else None

    @staticmethod
    async def _fetch_with_aiohttp(session, url: str, params: dict, headers: dict) -> Optional[str]:
        """
        Make one GET attempt with aiohttp.

        Returns the response URL as a string on HTTP 200. Returns None on any
        other status, on aiohttp.ClientError, or on asyncio.TimeoutError.
        """
        try:
            async with session.get(url, params=params, headers=headers) as resp:
                if resp.status == 200:
                    # ClientResponse.url is the URL of the final response.
                    return str(resp.url)
        except (aiohttp.ClientError, asyncio.TimeoutError):
            # aiohttp's total-timeout error is not a ClientError subclass, so
            # it is caught explicitly to keep the retry loop alive.
            return None
        return None

    @staticmethod
    async def create_image(
        generate_image_instruction: str,  # Text description for the image to generate
        image_format: str = "default",  # Format key from FORMATS dict
        model: Optional[str] = "flux-realism",  # Model name; None omits the parameter
        seed: Optional[int] = None,  # Optional seed for reproducible randomness
        nologo: bool = True,  # Whether to exclude logo watermark
        private: bool = True,  # Whether the image should be private
        enhance: bool = True,  # Whether to apply enhancement
        *,
        max_attempts: Optional[int] = None,  # None = retry forever (original behavior)
        retry_delay: float = 15,  # Seconds to wait between failed attempts
    ) -> str:
        """
        Asynchronously request an image from the generation service and return its URL.

        Each attempt first tries httpx. If that raises an httpx.HTTPError or
        returns a non-200 status, it then tries aiohttp. Failed attempts are
        followed by a ``retry_delay`` sleep. By default the loop repeats
        indefinitely; pass ``max_attempts`` to bound it.

        Args:
            generate_image_instruction (str): Text prompt describing the desired image.
            image_format (str): Key into FORMATS selecting the image dimensions.
            model (Optional[str]): Model to use. If None, no model parameter is sent.
            seed (Optional[int]): Seed for randomization control. Omitted if None.
            nologo (bool): Flag to exclude logo watermark.
            private (bool): Flag to mark image as private.
            enhance (bool): Flag to apply image enhancement.
            max_attempts (Optional[int]): Maximum number of attempts (each attempt
                tries httpx, then aiohttp). None means retry without limit.
            retry_delay (float): Seconds to sleep after a failed attempt.

        Returns:
            str: URL of the successful (HTTP 200) response.

        Raises:
            ValueError: If image_format is invalid or max_attempts is less than 1.
            RuntimeError: If max_attempts is set and every attempt failed.
        """
        # Validate inputs before doing any setup or network work.
        if image_format not in ImageGeneration.FORMATS:
            raise ValueError("Invalid image format.")
        if max_attempts is not None and max_attempts < 1:
            raise ValueError("max_attempts must be a positive integer or None.")

        width, height = ImageGeneration.FORMATS[image_format]

        # Initialize tools and get the image generation service endpoint URL.
        _, image_tool, _ = initialize_tools()

        # The prompt becomes a URL path component, so it must be percent-encoded.
        url = f"{image_tool}{quote(generate_image_instruction)}"

        params = ImageGeneration._build_params(
            width, height, model, seed, nologo, private, enhance
        )

        # The X-Forwarded-For value is generated once and reused for every attempt.
        headers = {"X-Forwarded-For": generate_ip()}

        attempt = 0
        # Both clients are opened once and reused across retries, instead of
        # creating a fresh aiohttp session on every loop iteration.
        async with httpx.AsyncClient(timeout=None) as client, aiohttp.ClientSession() as session:
            while True:
                attempt += 1

                result = await ImageGeneration._fetch_with_httpx(client, url, params, headers)
                if result is None:
                    result = await ImageGeneration._fetch_with_aiohttp(session, url, params, headers)
                if result is not None:
                    return result

                if max_attempts is not None and attempt >= max_attempts:
                    raise RuntimeError(
                        f"Image generation failed after {attempt} attempt(s)."
                    )

                # Back off before retrying to avoid overwhelming the server.
                await asyncio.sleep(retry_delay)