import torch
import spaces
import gradio as gr
from gradio_client import Client
from diffusers import FluxPipeline

# Use bfloat16 on GPU, fall back to float32 on CPU.
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32


# Load the LeX-FLUX pipeline once at startup.
def load_models():
    pipe = FluxPipeline.from_pretrained(
        "X-ART/LeX-FLUX",
        torch_dtype=torch_dtype,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe.to(device)
    return pipe


def prompt_enhance(client, image_caption, text_caption):
    """Call the LeX-Enhancer space to expand the user captions."""
    combined_caption, enhanced_caption = client.predict(
        image_caption,
        text_caption,
        api_name="/generate_enhanced_caption",
    )
    return combined_caption, enhanced_caption


pipe = load_models()


@spaces.GPU(duration=60)
def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):
    """Generate an image with LeX-FLUX."""
    pipe.enable_model_cpu_offload()
    # Seed 0 means "random"; any other value seeds the generator for reproducibility.
    generator = torch.Generator("cpu").manual_seed(seed) if seed != 0 else None
    image = pipe(
        enhanced_caption,
        height=1024,
        width=1024,
        guidance_scale=guidance_scale,
        output_type="pil",
        num_inference_steps=num_inference_steps,
        max_sequence_length=512,
        generator=generator,
    ).images[0]
    return image


def run_pipeline(image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer):
    """Run the complete pipeline from captions to the final image."""
    combined_caption = f"{image_caption}, with the text on it: {text_caption}."
    if enable_enhancer:
        client = Client("stzhao/LeX-Enhancer")
        combined_caption, enhanced_caption = prompt_enhance(client, image_caption, text_caption)
        print(f"Enhanced caption:\n{enhanced_caption}")
    else:
        enhanced_caption = combined_caption

    image = generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale)
    return image, combined_caption, enhanced_caption


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# LeX-Enhancer & LeX-FLUX Demo")
    gr.Markdown("## Project Page: https://zhaoshitian.github.io/lexart/")
    gr.Markdown(
        "Generate enhanced captions from simple image and text descriptions, "
        "then create images with LeX-FLUX."
    )

    with gr.Row():
        with gr.Column():
            image_caption = gr.Textbox(
                lines=2,
                label="Image Caption",
                placeholder="Describe the visual content of the image",
                value="A picture of a group of people gathered in front of a world map",
            )
            text_caption = gr.Textbox(
                lines=2,
                label="Text Caption",
                placeholder="Describe any text that should appear in the image",
                value="\"Communicate\" in purple, \"Execute\" in yellow",
            )

            with gr.Accordion("Advanced Settings", open=False):
                enable_enhancer = gr.Checkbox(
                    label="Enable LeX-Enhancer",
                    value=True,
                    info="When enabled, the caption will be enhanced before image generation",
                )
                seed = gr.Slider(
                    minimum=0,
                    maximum=100000,
                    value=0,
                    step=1,
                    label="Seed (0 for random)",
                )
                num_inference_steps = gr.Slider(
                    minimum=20,
                    maximum=100,
                    value=40,
                    step=1,
                    label="Number of Inference Steps",
                )
                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=7.5,
                    step=0.1,
                    label="Guidance Scale",
                )

            submit_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Generated Image")
            combined_caption_box = gr.Textbox(
                label="Combined Caption",
                interactive=False,
            )
            enhanced_caption_box = gr.Textbox(
                label="Enhanced Caption" if enable_enhancer.value else "Final Caption",
                interactive=False,
                lines=5,
            )

    # Example prompts
    examples = [
        ["A modern office workspace", "\"Innovation\" in bold blue letters at the center"],
        ["A beach sunset scene", "\"Relax\" in cursive white text in the corner"],
        ["A futuristic city skyline", "\"The Future is Now\" in neon pink glowing letters"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_caption, text_caption],
        label="Example Inputs",
    )

    # Keep the caption-box label in sync with the enhancer checkbox.
    def update_caption_label(enable_enhancer):
        return gr.Textbox(label="Enhanced Caption" if enable_enhancer else "Final Caption")

    enable_enhancer.change(
        fn=update_caption_label,
        inputs=enable_enhancer,
        outputs=enhanced_caption_box,
    )

    submit_btn.click(
        fn=run_pipeline,
        inputs=[image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer],
        outputs=[output_image, combined_caption_box, enhanced_caption_box],
    )

if __name__ == "__main__":
    demo.launch(debug=True)
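
# Example: querying this demo programmatically with gradio_client (a minimal,
# untested sketch). The "/run_pipeline" endpoint name assumes Gradio's default
# of deriving the API route from the handler function's name, and the URL is a
# placeholder for wherever the demo is actually served.
#
# from gradio_client import Client
#
# client = Client("http://127.0.0.1:7860")
# image_path, combined_caption, enhanced_caption = client.predict(
#     "A modern office workspace",            # image_caption
#     '"Innovation" in bold blue letters',    # text_caption
#     0,                                      # seed (0 = random)
#     28,                                     # num_inference_steps
#     3.5,                                    # guidance_scale
#     True,                                   # enable_enhancer
#     api_name="/run_pipeline",
# )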