"""Gradio demo: enhance a user prompt with an LLM, then render it as an image
via a Triton Inference Server (text2img ensemble) behind Alibaba PAI-EAS."""

import base64
import os
import random
import time
from io import BytesIO

import numpy as np
import tritonclient.http as httpclient
from PIL import Image
import gradio as gr
from openai import OpenAI

# NOTE(security): endpoint and auth token are hard-coded in source. Move them
# to environment variables / a secrets manager before sharing this file.
url = "1893706806886638.cn-beijing.pai-eas.aliyuncs.com/api/predict/prod_ad_fluxtritondeploy_1120"
authorization = "ODdhMGYxNmI1ZjJhN2E0NDEwM2QyZjcyYTlhY2UxZmZjNWY2M2FmZQ=="


def random_seed() -> int:
    """Return a random sampler seed in [0, 4294967294)."""
    # The previous implementation reseeded the module RNG from the wall clock
    # on every call, which is redundant (the RNG is OS-entropy-seeded at
    # import) and can repeat seeds for calls within the same clock tick.
    return random.randrange(4294967294)


def enhance_prompt(system_prompt, user_prompt):
    """Ask the LLM to rewrite *user_prompt* per *system_prompt*.

    Returns the enhanced prompt text from the first completion choice.
    """
    # NOTE(security): hard-coded API key — load from an env var instead.
    client = OpenAI(
        api_key="sk-rOjB00dtKBbSYIfgewn_KA",
        base_url="https://internal-skyscriptllm.skyreels.ai",
    )
    response = client.chat.completions.create(
        model="gpt-4o-2024-08-06",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )
    return response.choices[0].message.content


def generate_image(system_prompt, user_prompt, seed, height, width):
    """Enhance the prompt, then run text2img on the Triton "ensemble" model.

    Args:
        system_prompt: instructions for the prompt-enhancement LLM.
        user_prompt: the raw user prompt to enhance and render.
        seed: sampler seed; 0 or None picks a random one.
        height, width: output image dimensions in pixels.

    Returns:
        (enhanced_prompt, PIL.Image) tuple.
    """
    # First enhance the prompt
    enhanced_prompt = enhance_prompt(system_prompt, user_prompt)

    # Then generate the image using the enhanced prompt.
    triton_client = httpclient.InferenceServerClient(url=url, verbose=False, concurrency=2)

    class MyPlugin:
        """Injects the EAS auth token into every outgoing HTTP request."""

        def __call__(self, request):
            request.headers["Authorization"] = authorization

    triton_client.register_plugin(MyPlugin())

    # Gradio number/slider widgets may deliver floats; Triton expects INT64,
    # so normalize to int before building the request.
    seed = int(seed) if seed else 0
    height = int(height)
    width = int(width)
    if seed == 0:
        seed = random_seed()

    inputs_dict = {
        "request_type": "text2img",
        "prompt": enhanced_prompt,
        "seed": seed,
        "height": height,
        "width": width,
        "face": "",
    }
    print("Original prompt:", user_prompt)
    print("Enhanced prompt:", enhanced_prompt)
    print(inputs_dict)

    inputs = []
    for name, data in inputs_dict.items():
        if isinstance(data, str):
            payload = np.array([data.encode("utf-8")], dtype=np.object_).reshape([1, -1])
            infer_input = httpclient.InferInput(name, payload.shape, "BYTES")
        elif isinstance(data, int):
            payload = np.array([data], dtype=np.int64).reshape(1, 1)
            infer_input = httpclient.InferInput(name, payload.shape, "INT64")
        else:
            # Previously an unsupported value silently reused the prior
            # iteration's input (or raised NameError on the first item);
            # fail loudly instead.
            raise TypeError(f"Unsupported input type for '{name}': {type(data).__name__}")
        infer_input.set_data_from_numpy(payload)
        inputs.append(infer_input)

    outputs = [
        httpclient.InferRequestedOutput("o_image", binary_data=True),
    ]
    async_request = triton_client.async_infer(
        model_name="ensemble",
        inputs=inputs,
        outputs=outputs,
        timeout=60 * 20,
    )
    result = async_request.get_result()
    # The server returns a base64-encoded image; decode it into a PIL image.
    img = result.as_numpy("o_image")[0]
    buff = BytesIO(base64.b64decode(img))
    image = Image.open(buff)
    return enhanced_prompt, image


def launch_interface():
    """Build and launch the Gradio UI on port 7890 (public share link)."""
    # Set specific directory for gradio cache/temp files.
    gradio_temp_dir = "/maindata/data/shared/public/guibin.chen/gradio_cache"
    os.makedirs(gradio_temp_dir, exist_ok=True)
    os.environ["GRADIO_TEMP_DIR"] = gradio_temp_dir

    default_system_prompt = "You are an expert at writing detailed, creative and vivid image generation prompts. Enhance the user's prompt by adding more details and artistic direction while maintaining their original intent."

    interface = gr.Interface(
        fn=generate_image,
        inputs=[
            gr.Textbox(label="System Prompt", value=default_system_prompt, lines=3),
            gr.Textbox(label="User Prompt"),
            gr.Number(label="Seed (0 for random)", value=0, precision=0),
            gr.Slider(minimum=64, maximum=2048, value=720, step=8, label="Height"),
            gr.Slider(minimum=64, maximum=2048, value=1280, step=8, label="Width"),
        ],
        outputs=[
            gr.Textbox(label="Enhanced Prompt"),
            gr.Image(type="pil", label="Generated Image"),
        ],
        title="Enhanced Image Generation Interface",
        description="Generate images from text prompts with AI enhancement",
    )
    interface.launch(server_name="0.0.0.0", server_port=7890, share=True)


if __name__ == "__main__":
    launch_interface()