# LeX-FLUX / app.py
import gradio as gr
from gradio_client import Client
import torch
import spaces
from diffusers import FluxPipeline

# Pick a dtype suited to the hardware: bfloat16 on GPU, float32 on CPU.
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# def set_client_for_session(request: gr.Request):
#     x_ip_token = request.headers['x-ip-token']
#     # The "stzhao/LeX-Enhancer" space is a ZeroGPU space.
#     # return Client("stzhao/LeX-Enhancer", headers={"X-IP-Token": x_ip_token})
#     return Client("stzhao/LeX-Enhancer")

# Load models
def load_models():
    pipe = FluxPipeline.from_pretrained(
        "X-ART/LeX-FLUX",
        torch_dtype=torch_dtype,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe.to(device)
    return pipe

def prompt_enhance(client, image_caption, text_caption):
    """Ask the remote LeX-Enhancer space to expand the user captions."""
    combined_caption, enhanced_caption = client.predict(
        image_caption, text_caption, api_name="/generate_enhanced_caption"
    )
    return combined_caption, enhanced_caption
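
# A minimal usage sketch for the enhancer call (hypothetical captions; assumes
# the "stzhao/LeX-Enhancer" space is reachable and exposes the
# /generate_enhanced_caption endpoint used above):
#
#     client = Client("stzhao/LeX-Enhancer")
#     combined, enhanced = prompt_enhance(
#         client, "A coffee shop storefront", '"Open Daily" in red letters'
#     )
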
pipe = load_models()

# def truncate_caption_by_tokens(caption, max_tokens=256):
#     """Truncate the caption to fit within the max token limit."""
#     tokens = tokenizer.encode(caption)  # NOTE: would require a tokenizer to be loaded
#     if len(tokens) > max_tokens:
#         truncated_tokens = tokens[:max_tokens]
#         caption = tokenizer.decode(truncated_tokens, skip_special_tokens=True)
#         print(f"Caption was truncated from {len(tokens)} tokens to {max_tokens} tokens")
#     return caption

@spaces.GPU(duration=60)
def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):
    """Generate an image with LeX-FLUX."""
    pipe.enable_model_cpu_offload()
    # Caption truncation is currently disabled:
    # enhanced_caption = truncate_caption_by_tokens(enhanced_caption, max_tokens=256)
    # A seed of 0 means "random": pass no generator so a fresh seed is drawn.
    generator = torch.Generator("cpu").manual_seed(seed) if seed != 0 else None
    image = pipe(
        enhanced_caption,
        height=1024,
        width=1024,
        guidance_scale=guidance_scale,
        output_type="pil",
        num_inference_steps=num_inference_steps,
        max_sequence_length=512,
        generator=generator,
    ).images[0]
    return image
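
# Standalone usage sketch (hypothetical prompt and settings; on a ZeroGPU
# space the GPU is only attached while the @spaces.GPU-decorated call runs):
#
#     img = generate_image(
#         'A poster that reads "HELLO" in bold black type',
#         seed=42, num_inference_steps=28, guidance_scale=3.5,
#     )
#     img.save("out.png")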

# @spaces.GPU(duration=130)
def run_pipeline(image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer):
    """Run the complete pipeline from captions to final image."""
    combined_caption = f"{image_caption}, with the text on it: {text_caption}."
    if enable_enhancer:
        client = Client("stzhao/LeX-Enhancer")
        combined_caption, enhanced_caption = prompt_enhance(client, image_caption, text_caption)
        print(f"enhanced caption:\n{enhanced_caption}")
    else:
        enhanced_caption = combined_caption
    image = generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale)
    return image, combined_caption, enhanced_caption
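
# End-to-end sketch without the UI (hypothetical inputs; enhancer disabled so
# no network call to the LeX-Enhancer space is made):
#
#     image, combined, enhanced = run_pipeline(
#         "A bakery window at dawn", '"Fresh Bread" in gold script',
#         seed=0, num_inference_steps=40, guidance_scale=7.5,
#         enable_enhancer=False,
#     )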

# Gradio interface
with gr.Blocks() as demo:
    # client = gr.State()
    gr.Markdown("# LeX-Enhancer & LeX-FLUX Demo")
    gr.Markdown("## Project Page: https://zhaoshitian.github.io/lexart/")
    gr.Markdown("Generate enhanced captions from simple image and text descriptions, then create images with LeX-FLUX.")

    with gr.Row():
        with gr.Column():
            image_caption = gr.Textbox(
                lines=2,
                label="Image Caption",
                placeholder="Describe the visual content of the image",
                value="A picture of a group of people gathered in front of a world map",
            )
            text_caption = gr.Textbox(
                lines=2,
                label="Text Caption",
                placeholder="Describe any text that should appear in the image",
                value="\"Communicate\" in purple, \"Execute\" in yellow",
            )
            with gr.Accordion("Advanced Settings", open=False):
                enable_enhancer = gr.Checkbox(
                    label="Enable LeX-Enhancer",
                    value=True,
                    info="When enabled, the caption will be enhanced before image generation",
                )
                seed = gr.Slider(
                    minimum=0,
                    maximum=100000,
                    value=0,
                    step=1,
                    label="Seed (0 for random)",
                )
                num_inference_steps = gr.Slider(
                    minimum=20,
                    maximum=100,
                    value=40,
                    step=1,
                    label="Number of Inference Steps",
                )
                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=7.5,
                    step=0.1,
                    label="Guidance Scale",
                )

            submit_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="Generated Image")
            combined_caption_box = gr.Textbox(
                label="Combined Caption",
                interactive=False,
            )
            enhanced_caption_box = gr.Textbox(
                label="Enhanced Caption" if enable_enhancer.value else "Final Caption",
                interactive=False,
                lines=5,
            )

    # Example prompts
    examples = [
        ["A modern office workspace", "\"Innovation\" in bold blue letters at the center"],
        ["A beach sunset scene", "\"Relax\" in cursive white text in the corner"],
        ["A futuristic city skyline", "\"The Future is Now\" in neon pink glowing letters"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_caption, text_caption],
        label="Example Inputs",
    )

    # Update the label of enhanced_caption_box based on the checkbox state.
    def update_caption_label(enable_enhancer):
        return gr.Textbox(label="Enhanced Caption" if enable_enhancer else "Final Caption")

    enable_enhancer.change(
        fn=update_caption_label,
        inputs=enable_enhancer,
        outputs=enhanced_caption_box,
    )

    submit_btn.click(
        fn=run_pipeline,
        inputs=[image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer],
        outputs=[output_image, combined_caption_box, enhanced_caption_box],
    )

    # demo.load(set_client_for_session, None, client)

if __name__ == "__main__":
    demo.launch(debug=True)