|
import os |
|
import gradio as gr |
|
from gradio_client import Client, handle_file |
|
import torch |
|
import spaces |
|
from diffusers import FluxPipeline |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
# Precision used for model weights: bfloat16 when CUDA is available,
# full float32 otherwise.
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_models():
    """Load the LeX-FLUX pipeline and place it on the detected device.

    The weights are loaded with the module-level ``torch_dtype`` and the
    pipeline is moved to ``"cuda"`` when CUDA is available, ``"cpu"``
    otherwise.

    Returns:
        The ``FluxPipeline`` loaded from ``"X-ART/LeX-FLUX"``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    pipe = FluxPipeline.from_pretrained(
        "X-ART/LeX-FLUX",
        torch_dtype=torch_dtype,
    )
    # Use the detected device: a hard-coded "cuda" fails on CPU-only hosts.
    pipe.to(device)

    return pipe
|
|
|
def prompt_enhance(client, image_caption, text_caption):
    """Ask the remote enhancer endpoint for a combined and an enhanced caption.

    Args:
        client: Object exposing ``predict(...)``; called with the
            ``/generate_enhanced_caption`` API name.
        image_caption: Description of the visual content.
        text_caption: Description of the text to render in the image.

    Returns:
        A ``(combined_caption, enhanced_caption)`` tuple unpacked from the
        endpoint's two-element result.
    """
    result = client.predict(
        image_caption,
        text_caption,
        api_name="/generate_enhanced_caption",
    )
    # Unpacking enforces that the endpoint returned exactly two values.
    combined_caption, enhanced_caption = result
    return combined_caption, enhanced_caption
|
|
|
|
|
# Load the pipeline once at import time so every request reuses the same
# instance through this module-level name.
pipe = load_models()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@spaces.GPU(duration=60)
def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):
    """Generate a 1024x1024 image from a caption using the LeX-FLUX pipeline.

    Args:
        enhanced_caption: Prompt text passed to the pipeline.
        seed: Seed for the CPU random generator. ``0`` means no fixed seed
            (the pipeline receives ``generator=None``).
        num_inference_steps: Number of inference steps forwarded to the
            pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.

    Returns:
        The first image of the pipeline output (``output_type="pil"``).
    """
    # Re-applied on every call, as before.
    # NOTE(review): this probably only needs to run once - confirm.
    pipe.enable_model_cpu_offload()

    # UI sliders may deliver floats; manual_seed() requires an int.
    seed = int(seed)
    generator = torch.Generator("cpu").manual_seed(seed) if seed != 0 else None

    # Forward the caller-supplied settings; previously they were ignored in
    # favour of hard-coded steps, guidance scale and a fixed seed of 0.
    image = pipe(
        enhanced_caption,
        height=1024,
        width=1024,
        guidance_scale=float(guidance_scale),
        output_type="pil",
        num_inference_steps=int(num_inference_steps),
        max_sequence_length=512,
        generator=generator,
    ).images[0]

    print(image)

    return image
|
|
|
|
|
def run_pipeline(image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer):
    """Turn the two captions into a final image, optionally enhancing the prompt.

    Args:
        image_caption: Description of the visual content.
        text_caption: Description of the text to render in the image.
        seed: Forwarded to ``generate_image``.
        num_inference_steps: Forwarded to ``generate_image``.
        guidance_scale: Forwarded to ``generate_image``.
        enable_enhancer: When truthy, the captions are sent to the
            ``stzhao/LeX-Enhancer`` Space before generation.

    Returns:
        A ``(image, combined_caption, enhanced_caption)`` tuple.
    """
    # Default path: the plain combined caption is also the final prompt.
    combined_caption = f"{image_caption}, with the text on it: {text_caption}."
    enhanced_caption = combined_caption

    if enable_enhancer:
        # Both captions are replaced by whatever the enhancer returns.
        enhancer_client = Client("stzhao/LeX-Enhancer")
        combined_caption, enhanced_caption = prompt_enhance(
            enhancer_client, image_caption, text_caption
        )
        print(f"enhanced caption:\n{enhanced_caption}")

    image = generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale)

    return image, combined_caption, enhanced_caption
|
|
|
|
|
with gr.Blocks() as demo:
    # NOTE(review): the nesting below is reconstructed from statement order
    # and grouping - confirm it matches the intended layout.
    gr.Markdown("# LeX-Enhancer & LeX-FLUX Demo")
    gr.Markdown("## Project Page: https://zhaoshitian.github.io/lexart/")
    gr.Markdown("Generate enhanced captions from simple image and text descriptions, then create images with LeX-FLUX")

    with gr.Row():
        # Left column: prompt inputs, advanced settings and the trigger button.
        with gr.Column():
            image_caption = gr.Textbox(
                label="Image Caption",
                value="A picture of a group of people gathered in front of a world map",
                placeholder="Describe the visual content of the image",
                lines=2,
            )
            text_caption = gr.Textbox(
                label="Text Caption",
                value='"Communicate" in purple, "Execute" in yellow',
                placeholder="Describe any text that should appear in the image",
                lines=2,
            )

            with gr.Accordion("Advanced Settings", open=False):
                enable_enhancer = gr.Checkbox(
                    value=True,
                    label="Enable LeX-Enhancer",
                    info="When enabled, the caption will be enhanced before image generation",
                )
                seed = gr.Slider(
                    label="Seed (0 for random)",
                    minimum=0,
                    maximum=100000,
                    step=1,
                    value=0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of Inference Steps",
                    minimum=20,
                    maximum=100,
                    step=1,
                    value=40,
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.5,
                )

            submit_btn = gr.Button("Generate", variant="primary")

        # Right column: generated image and the captions that produced it.
        with gr.Column():
            output_image = gr.Image(label="Generated Image")
            combined_caption_box = gr.Textbox(label="Combined Caption", interactive=False)
            enhanced_caption_box = gr.Textbox(
                # Initial label follows the checkbox's default value.
                label="Enhanced Caption" if enable_enhancer.value else "Final Caption",
                lines=5,
                interactive=False,
            )

    # Clickable sample prompts that fill the two caption boxes.
    examples = [
        ["A modern office workspace", '"Innovation" in bold blue letters at the center'],
        ["A beach sunset scene", '"Relax" in cursive white text in the corner'],
        ["A futuristic city skyline", '"The Future is Now" in neon pink glowing letters'],
    ]
    gr.Examples(
        label="Example Inputs",
        examples=examples,
        inputs=[image_caption, text_caption],
    )

    def update_caption_label(enable_enhancer):
        """Return a Textbox update whose label reflects the enhancer toggle."""
        if enable_enhancer:
            return gr.Textbox(label="Enhanced Caption")
        return gr.Textbox(label="Final Caption")

    # Keep the caption box label in sync with the checkbox.
    enable_enhancer.change(
        fn=update_caption_label,
        inputs=enable_enhancer,
        outputs=enhanced_caption_box,
    )

    # Run the full caption -> image pipeline when the button is pressed.
    submit_btn.click(
        fn=run_pipeline,
        inputs=[
            image_caption,
            text_caption,
            seed,
            num_inference_steps,
            guidance_scale,
            enable_enhancer,
        ],
        outputs=[output_image, combined_caption_box, enhanced_caption_box],
    )
|
|
|
|
|
|
|
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True)