|
import base64 |
|
from io import BytesIO |
|
import random |
|
import time |
|
import os |
|
|
|
import numpy as np |
|
import tritonclient.http as httpclient |
|
from PIL import Image |
|
import gradio as gr |
|
from openai import OpenAI |
|
|
|
# Endpoint and auth token for the PAI-EAS Triton deployment.
# SECURITY NOTE(review): these credentials were hard-coded in source. They can
# now be overridden via environment variables; the original literals remain as
# fallbacks so existing deployments keep working unchanged.
url = os.environ.get(
    "EAS_TRITON_URL",
    "1893706806886638.cn-beijing.pai-eas.aliyuncs.com/api/predict/prod_ad_fluxtritondeploy_1120",
)
authorization = os.environ.get(
    "EAS_TRITON_AUTH",
    "ODdhMGYxNmI1ZjJhN2E0NDEwM2QyZjcyYTlhY2UxZmZjNWY2M2FmZQ==",
)
|
|
|
def random_seed(max_exclusive: int = 4294967294) -> int:
    """Return a random seed in ``[0, max_exclusive)``.

    Args:
        max_exclusive: Exclusive upper bound for the seed. Defaults to
            4294967294 (just under 2**32 - 1), matching the original
            hard-coded bound.

    Returns:
        A non-negative int suitable as an image-generation seed.

    Note:
        The previous implementation reseeded the global RNG with
        ``time.time()`` on every call, which could produce identical seeds
        for two calls landing in the same clock tick. Drawing from the
        already-seeded global RNG avoids that while staying deterministic
        only if the caller explicitly seeds ``random`` themselves.
    """
    return random.randrange(max_exclusive)
|
|
|
def enhance_prompt(system_prompt, user_prompt):
    """Ask the LLM service to rewrite *user_prompt* into a richer image prompt.

    Args:
        system_prompt: Instructions controlling how the prompt is enhanced.
        user_prompt: The raw prompt supplied by the user.

    Returns:
        The enhanced prompt text (first chat completion choice).

    SECURITY NOTE(review): the API key and base URL were hard-coded in source.
    They can now be supplied via SKYSCRIPT_API_KEY / SKYSCRIPT_BASE_URL env
    vars; the original literals remain as fallbacks for backward compatibility,
    but the key should be rotated and removed from the repository.
    """
    api_key = os.environ.get("SKYSCRIPT_API_KEY", "sk-rOjB00dtKBbSYIfgewn_KA")
    base_url = os.environ.get(
        "SKYSCRIPT_BASE_URL", "https://internal-skyscriptllm.skyreels.ai"
    )
    client = OpenAI(api_key=api_key, base_url=base_url)
    response = client.chat.completions.create(
        model="gpt-4o-2024-08-06",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )
    return response.choices[0].message.content
|
|
|
def _build_infer_inputs(inputs_dict):
    """Convert a name -> value dict into a list of Triton ``InferInput`` tensors.

    Strings become 1x1 BYTES tensors (UTF-8 encoded); ints become 1x1 INT64
    tensors. Any other value type raises ``TypeError`` — the original code
    silently fell through and crashed later with a ``NameError``.
    """
    tensors = []
    for name, value in inputs_dict.items():
        if isinstance(value, str):
            arr = np.array([value.encode("utf-8")], dtype=np.object_).reshape([1, -1])
            tensor = httpclient.InferInput(name, arr.shape, "BYTES")
        elif isinstance(value, int):
            arr = np.array([value], dtype=np.int64).reshape(1, 1)
            tensor = httpclient.InferInput(name, arr.shape, "INT64")
        else:
            raise TypeError(
                f"Unsupported input type for '{name}': {type(value).__name__}"
            )
        tensor.set_data_from_numpy(arr)
        tensors.append(tensor)
    return tensors


def generate_image(system_prompt, user_prompt, seed, height, width):
    """Enhance the prompt, then generate an image via the Triton "ensemble" model.

    Args:
        system_prompt: System instructions passed to the prompt-enhancement LLM.
        user_prompt: The user's original image prompt.
        seed: RNG seed; 0 or None requests a fresh random seed.
        height: Output image height in pixels.
        width: Output image width in pixels.

    Returns:
        Tuple of (enhanced_prompt: str, image: PIL.Image.Image).
    """
    enhanced_prompt = enhance_prompt(system_prompt, user_prompt)

    triton_client = httpclient.InferenceServerClient(
        url=url, verbose=False, concurrency=2
    )

    # Attach the EAS Authorization header to every outgoing request.
    class _AuthPlugin:
        def __call__(self, request):
            request.headers["Authorization"] = authorization

    triton_client.register_plugin(_AuthPlugin())

    if seed == 0 or seed is None:
        seed = random_seed()
    # Gradio Number/Slider widgets may deliver floats; the Triton inputs must
    # be exact ints or _build_infer_inputs would reject them.
    seed = int(seed)
    height = int(height)
    width = int(width)

    inputs_dict = {
        "request_type": "text2img",
        "prompt": enhanced_prompt,
        "seed": seed,
        "height": height,
        "width": width,
        "face": "",
    }
    print("Original prompt:", user_prompt)
    print("Enhanced prompt:", enhanced_prompt)
    print(inputs_dict)

    inputs = _build_infer_inputs(inputs_dict)
    outputs = [
        httpclient.InferRequestedOutput("o_image", binary_data=True),
    ]
    async_request = triton_client.async_infer(
        model_name="ensemble",
        inputs=inputs,
        outputs=outputs,
        timeout=60 * 20,
    )

    # Block until the server responds; "o_image" is a base64-encoded image.
    result = async_request.get_result()
    encoded = result.as_numpy("o_image")[0]
    image = Image.open(BytesIO(base64.b64decode(encoded)))
    return enhanced_prompt, image
|
|
|
def launch_interface():
    """Configure the Gradio temp directory and serve the image-generation UI."""
    # Keep Gradio's temp files on the shared volume rather than /tmp.
    temp_dir = "/maindata/data/shared/public/guibin.chen/gradio_cache"
    os.makedirs(temp_dir, exist_ok=True)
    os.environ["GRADIO_TEMP_DIR"] = temp_dir

    system_prompt_default = (
        "You are an expert at writing detailed, creative and vivid image "
        "generation prompts. Enhance the user's prompt by adding more details "
        "and artistic direction while maintaining their original intent."
    )

    input_widgets = [
        gr.Textbox(label="System Prompt", value=system_prompt_default, lines=3),
        gr.Textbox(label="User Prompt"),
        gr.Number(label="Seed (0 for random)", value=0, precision=0),
        gr.Slider(minimum=64, maximum=2048, value=720, step=8, label="Height"),
        gr.Slider(minimum=64, maximum=2048, value=1280, step=8, label="Width"),
    ]
    output_widgets = [
        gr.Textbox(label="Enhanced Prompt"),
        gr.Image(type="pil", label="Generated Image"),
    ]

    demo = gr.Interface(
        fn=generate_image,
        inputs=input_widgets,
        outputs=output_widgets,
        title="Enhanced Image Generation Interface",
        description="Generate images from text prompts with AI enhancement",
    )
    demo.launch(server_name="0.0.0.0", server_port=7890, share=True)
|
|
|
# Script entry point: start the Gradio web UI when run directly.
if __name__ == "__main__":
    launch_interface()
|
|