File size: 4,108 Bytes
e9fa53a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import base64
from io import BytesIO
import random
import time
import os

import numpy as np
import tritonclient.http as httpclient
from PIL import Image
import gradio as gr
from openai import OpenAI

# Address of the Triton inference server handed to tritonclient, and the value
# sent in the "Authorization" header of every inference request.
# SECURITY(review): the fallback literals below are deployment credentials
# committed to source. Prefer supplying TRITON_URL / TRITON_AUTHORIZATION via
# the environment and rotating the committed values.
url = os.environ.get(
    "TRITON_URL",
    "1893706806886638.cn-beijing.pai-eas.aliyuncs.com/api/predict/prod_ad_fluxtritondeploy_1120",
)
authorization = os.environ.get(
    "TRITON_AUTHORIZATION",
    "ODdhMGYxNmI1ZjJhN2E0NDEwM2QyZjcyYTlhY2UxZmZjNWY2M2FmZQ==",
)

def random_seed() -> int:
    """Return a fresh pseudo-random seed in the range ``[0, 4294967294)``.

    The value is drawn from a private ``random.Random`` instance rather than
    by reseeding the module-level generator, so calling this function does
    not disturb any other code that relies on the global ``random`` state.
    """
    return random.Random().randrange(4294967294)

def enhance_prompt(system_prompt, user_prompt):
    """Ask the chat-completion service to rewrite ``user_prompt``.

    Args:
        system_prompt: Text sent as the ``system`` message.
        user_prompt: Text sent as the ``user`` message.

    Returns:
        The ``content`` of the first choice in the completion response.
    """
    # SECURITY(review): the fallback key is a credential committed to source.
    # Prefer supplying SKYSCRIPT_LLM_API_KEY via the environment and rotating
    # the committed value.
    api_key = os.environ.get("SKYSCRIPT_LLM_API_KEY", "sk-rOjB00dtKBbSYIfgewn_KA")
    client = OpenAI(api_key=api_key, base_url="https://internal-skyscriptllm.skyreels.ai")
    response = client.chat.completions.create(
        model="gpt-4o-2024-08-06",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )
    return response.choices[0].message.content

def _build_infer_inputs(inputs_dict):
    """Convert a ``{name: value}`` mapping into a list of Triton ``InferInput`` objects.

    Strings become ``BYTES`` tensors (UTF-8 encoded) and integers become
    ``INT64`` tensors of shape ``(1, 1)``.

    Raises:
        TypeError: If a value is neither ``str`` nor an integer. Failing loudly
            avoids silently re-sending the previous tensor for that field.
    """
    infer_inputs = []
    for name, data in inputs_dict.items():
        if isinstance(data, str):
            array = np.array([data.encode("utf-8")], dtype=np.object_).reshape([1, -1])
            datatype = "BYTES"
        elif isinstance(data, (int, np.integer)) and not isinstance(data, bool):
            array = np.array([data], dtype=np.int64).reshape(1, 1)
            datatype = "INT64"
        else:
            raise TypeError(
                f"Unsupported type for input {name!r}: {type(data).__name__}"
            )
        infer_input = httpclient.InferInput(name, array.shape, datatype)
        infer_input.set_data_from_numpy(array)
        infer_inputs.append(infer_input)
    return infer_inputs


def generate_image(system_prompt, user_prompt, seed, height, width):
    """Enhance a prompt, then request a text-to-image generation from Triton.

    Args:
        system_prompt: System message forwarded to ``enhance_prompt``.
        user_prompt: The user's raw prompt.
        seed: Generation seed; ``0`` or ``None`` selects a random seed.
        height: Requested image height (coerced to ``int``).
        width: Requested image width (coerced to ``int``).

    Returns:
        A ``(enhanced_prompt, image)`` tuple where ``image`` is the PIL image
        opened from the base64-decoded ``o_image`` output.

    Raises:
        TypeError / ValueError: If ``height`` or ``width`` cannot be converted
            to ``int``.
    """
    # UI widgets may deliver whole numbers as floats; normalise up front so
    # they are serialised as INT64 tensors and bad values fail before any
    # network call is made.
    seed = int(seed) if seed else random_seed()
    height = int(height)
    width = int(width)

    # First enhance the prompt
    enhanced_prompt = enhance_prompt(system_prompt, user_prompt)

    inputs_dict = {
        "request_type": "text2img",
        "prompt": enhanced_prompt,
        "seed": seed,
        "height": height,
        "width": width,
        "face": "",
    }
    print("Original prompt:", user_prompt)
    print("Enhanced prompt:", enhanced_prompt)
    print(inputs_dict)

    class MyPlugin:
        """Request hook that attaches the Authorization header."""

        def __call__(self, request):
            request.headers["Authorization"] = authorization

    # Then generate image using enhanced prompt
    triton_client = httpclient.InferenceServerClient(url=url, verbose=False, concurrency=2)
    try:
        triton_client.register_plugin(MyPlugin())
        async_request = triton_client.async_infer(
            model_name="ensemble",
            inputs=_build_infer_inputs(inputs_dict),
            outputs=[
                httpclient.InferRequestedOutput("o_image", binary_data=True),
            ],
            timeout=60 * 20,
        )
        result = async_request.get_result()
        img = result.as_numpy("o_image")[0]
    finally:
        # A new client is created per call; release its connections even when
        # the request fails.
        triton_client.close()

    buff = BytesIO(base64.b64decode(img))
    image = Image.open(buff)
    return enhanced_prompt, image

def launch_interface():
    """Build the Gradio UI around ``generate_image`` and start serving it.

    Ensures the Gradio temp directory exists and exports it through
    ``GRADIO_TEMP_DIR`` before the interface is created, then launches the
    app on all interfaces at port 7890 with sharing enabled.
    """
    # Point gradio's cache/temp files at a dedicated directory.
    cache_dir = "/maindata/data/shared/public/guibin.chen/gradio_cache"
    os.makedirs(cache_dir, exist_ok=True)
    os.environ["GRADIO_TEMP_DIR"] = cache_dir

    system_prompt_default = "You are an expert at writing detailed, creative and vivid image generation prompts. Enhance the user's prompt by adding more details and artistic direction while maintaining their original intent."

    # Components are listed in the positional order generate_image expects.
    input_components = [
        gr.Textbox(label="System Prompt", value=system_prompt_default, lines=3),
        gr.Textbox(label="User Prompt"),
        gr.Number(label="Seed (0 for random)", value=0, precision=0),
        gr.Slider(minimum=64, maximum=2048, value=720, step=8, label="Height"),
        gr.Slider(minimum=64, maximum=2048, value=1280, step=8, label="Width"),
    ]
    # Mirrors generate_image's (enhanced_prompt, image) return value.
    output_components = [
        gr.Textbox(label="Enhanced Prompt"),
        gr.Image(type="pil", label="Generated Image"),
    ]

    app = gr.Interface(
        fn=generate_image,
        inputs=input_components,
        outputs=output_components,
        title="Enhanced Image Generation Interface",
        description="Generate images from text prompts with AI enhancement",
    )
    app.launch(server_name="0.0.0.0", server_port=7890, share=True)

# Start the web UI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    launch_interface()