# gbc-backup / content / test_flux_gradio.py
# sicer — "Initial commit from existing repo" (revision e9fa53a)
import base64
from io import BytesIO
import random
import time
import os
import numpy as np
import tritonclient.http as httpclient
from PIL import Image
import gradio as gr
from openai import OpenAI
# Alibaba Cloud PAI-EAS endpoint hosting the Flux Triton deployment.
url = "1893706806886638.cn-beijing.pai-eas.aliyuncs.com/api/predict/prod_ad_fluxtritondeploy_1120"
# NOTE(review): this is a credential committed to the repository — it should be
# rotated and loaded from the environment or a secrets store instead.
authorization = "ODdhMGYxNmI1ZjJhN2E0NDEwM2QyZjcyYTlhY2UxZmZjNWY2M2FmZQ=="
def random_seed() -> int:
    """Return a random seed in [1, 4294967294] for the image backend.

    Returns:
        int: a uniformly random seed.

    Notes:
        The previous implementation reseeded the global RNG with
        ``time.time()`` on every call, which yields identical seeds for
        calls landing on the same timestamp; the module RNG is already
        seeded from OS entropy at import, so no reseeding is needed.
        The lower bound is 1 so the result can never collide with the
        sentinel value 0, which generate_image() treats as "pick a
        random seed".
    """
    return random.randrange(1, 4294967295)
def enhance_prompt(system_prompt, user_prompt):
    """Ask the LLM endpoint to rewrite *user_prompt* with more detail.

    Args:
        system_prompt: instructions telling the model how to enhance prompts.
        user_prompt: the raw prompt typed by the user.

    Returns:
        str: the enhanced prompt text from the first completion choice.

    Notes:
        SECURITY(review): the API key was hard-coded in source. It can now
        be supplied via the ``LLM_API_KEY`` environment variable; the
        hard-coded value remains only as a backward-compatible fallback
        and should be rotated and removed.
    """
    # Prefer environment-supplied credentials so secrets need not live in git.
    api_key = os.environ.get("LLM_API_KEY", "sk-rOjB00dtKBbSYIfgewn_KA")
    base_url = os.environ.get("LLM_BASE_URL", "https://internal-skyscriptllm.skyreels.ai")
    client = OpenAI(api_key=api_key, base_url=base_url)
    response = client.chat.completions.create(
        model="gpt-4o-2024-08-06",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )
    return response.choices[0].message.content
def _make_infer_input(name, data):
    """Build one tritonclient InferInput for a request field.

    Strings are sent as BYTES (utf-8); integral numbers as INT64. Raises
    TypeError for anything else — the original loop silently reused the
    previous input object (or raised NameError) when a value matched
    neither branch.
    """
    if isinstance(data, str):
        arr = np.array([data.encode("utf-8")], dtype=np.object_).reshape([1, -1])
        infer_input = httpclient.InferInput(name, arr.shape, "BYTES")
    elif isinstance(data, (int, float)) and not isinstance(data, bool):
        # Gradio sliders/number boxes may deliver floats; coerce to int64.
        arr = np.array([int(data)], dtype=np.int64).reshape(1, 1)
        infer_input = httpclient.InferInput(name, arr.shape, "INT64")
    else:
        raise TypeError(f"unsupported input type for field {name!r}: {type(data).__name__}")
    infer_input.set_data_from_numpy(arr)
    return infer_input


def generate_image(system_prompt, user_prompt, seed, height, width):
    """Enhance *user_prompt* via the LLM, then render an image on Triton.

    Args:
        system_prompt: instructions for the prompt-enhancement model.
        user_prompt: the raw prompt typed by the user.
        seed: integer RNG seed; 0 or None means "choose randomly".
        height: requested image height in pixels.
        width: requested image width in pixels.

    Returns:
        tuple[str, PIL.Image.Image]: the enhanced prompt and the decoded
        image, matching the two Gradio outputs.
    """
    # First enhance the prompt, then generate the image from it.
    enhanced_prompt = enhance_prompt(system_prompt, user_prompt)

    triton_client = httpclient.InferenceServerClient(url=url, verbose=False, concurrency=2)

    # Plugin that injects the EAS authorization header into every request.
    class _AuthPlugin:
        def __call__(self, request):
            request.headers["Authorization"] = authorization

    triton_client.register_plugin(_AuthPlugin())

    if seed == 0 or seed is None:
        seed = random_seed()

    inputs_dict = {
        "request_type": "text2img",
        "prompt": enhanced_prompt,
        "seed": seed,
        "height": height,
        "width": width,
        "face": "",
    }
    print("Original prompt:", user_prompt)
    print("Enhanced prompt:", enhanced_prompt)
    print(inputs_dict)

    inputs = [_make_infer_input(name, data) for name, data in inputs_dict.items()]
    outputs = [
        httpclient.InferRequestedOutput("o_image", binary_data=True),
    ]
    async_request = triton_client.async_infer(
        model_name="ensemble",
        inputs=inputs,
        outputs=outputs,
        timeout=60 * 20,  # 20 minutes, expressed in seconds
    )
    result = async_request.get_result()

    # The server returns the image as a base64-encoded payload.
    img_b64 = result.as_numpy("o_image")[0]
    image = Image.open(BytesIO(base64.b64decode(img_b64)))
    return enhanced_prompt, image
def launch_interface():
    """Build and launch the Gradio web UI for prompt-enhanced image generation.

    Binds on all interfaces at port 7890 and requests a public share link.
    """
    # Respect an existing GRADIO_TEMP_DIR; the original unconditionally
    # overwrote it with a machine-specific path that only exists on one host.
    gradio_temp_dir = os.environ.get(
        "GRADIO_TEMP_DIR",
        "/maindata/data/shared/public/guibin.chen/gradio_cache",
    )
    os.makedirs(gradio_temp_dir, exist_ok=True)
    os.environ["GRADIO_TEMP_DIR"] = gradio_temp_dir

    default_system_prompt = "You are an expert at writing detailed, creative and vivid image generation prompts. Enhance the user's prompt by adding more details and artistic direction while maintaining their original intent."

    interface = gr.Interface(
        fn=generate_image,
        inputs=[
            gr.Textbox(label="System Prompt", value=default_system_prompt, lines=3),
            gr.Textbox(label="User Prompt"),
            gr.Number(label="Seed (0 for random)", value=0, precision=0),
            gr.Slider(minimum=64, maximum=2048, value=720, step=8, label="Height"),
            gr.Slider(minimum=64, maximum=2048, value=1280, step=8, label="Width"),
        ],
        outputs=[
            gr.Textbox(label="Enhanced Prompt"),
            gr.Image(type="pil", label="Generated Image"),
        ],
        title="Enhanced Image Generation Interface",
        description="Generate images from text prompts with AI enhancement",
    )
    interface.launch(server_name="0.0.0.0", server_port=7890, share=True)
if __name__ == "__main__":
launch_interface()