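"""Gradio demo for LeX-FLUX.

Optionally enhances a pair of user captions via the remote stzhao/LeX-Enhancer
Space, then renders a text-bearing image with the X-ART/LeX-FLUX pipeline.
"""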
import os
import gradio as gr
from gradio_client import Client, handle_file
import torch
import spaces
from diffusers import FluxPipeline
from transformers import AutoModelForCausalLM, AutoTokenizer
# Select a dtype suited to the hardware: bfloat16 on GPU, float32 on CPU.
if torch.cuda.is_available():
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float32
# def set_client_for_session(request: gr.Request):
#     x_ip_token = request.headers['x-ip-token']
#     # The "gradio/text-to-image" space is a ZeroGPU space
#     # return Client("stzhao/LeX-Enhancer", headers={"X-IP-Token": x_ip_token})
#     return Client("stzhao/LeX-Enhancer")
# Load models
def load_models():
    """Load the LeX-FLUX pipeline with the dtype selected above."""
    pipe = FluxPipeline.from_pretrained(
        "X-ART/LeX-FLUX",
        torch_dtype=torch_dtype,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe.to(device)
    return pipe
def prompt_enhance(client, image_caption, text_caption):
    """Call the remote LeX-Enhancer Space to expand the captions."""
    combined_caption, enhanced_caption = client.predict(
        image_caption, text_caption, api_name="/generate_enhanced_caption"
    )
    return combined_caption, enhanced_caption
pipe = load_models()
# def truncate_caption_by_tokens(caption, max_tokens=256):
#     """Truncate the caption to fit within the max token limit"""
#     tokens = tokenizer.encode(caption)
#     if len(tokens) > max_tokens:
#         truncated_tokens = tokens[:max_tokens]
#         caption = tokenizer.decode(truncated_tokens, skip_special_tokens=True)
#         print(f"Caption was truncated from {len(tokens)} tokens to {max_tokens} tokens")
#     return caption
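# On Hugging Face ZeroGPU Spaces, @spaces.GPU allocates a GPU for each
# decorated call, released after at most `duration` seconds.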
@spaces.GPU(duration=60)
def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):
    """Generate an image from the enhanced caption using LeX-FLUX."""
    pipe.enable_model_cpu_offload()
    # Truncate the caption if it's too long
    # enhanced_caption = truncate_caption_by_tokens(enhanced_caption, max_tokens=256)
    # Seed 0 means "random": pass no generator so a fresh seed is drawn.
    generator = torch.Generator("cpu").manual_seed(seed) if seed != 0 else None
    image = pipe(
        enhanced_caption,
        height=1024,
        width=1024,
        guidance_scale=guidance_scale,
        output_type="pil",
        num_inference_steps=num_inference_steps,
        max_sequence_length=512,
        generator=generator,
    ).images[0]
    return image
# @spaces.GPU(duration=130)
def run_pipeline(image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer):
    """Run the complete pipeline from captions to final image"""
    combined_caption = f"{image_caption}, with the text on it: {text_caption}."
    if enable_enhancer:
        # combined_caption, enhanced_caption = generate_enhanced_caption(image_caption, text_caption)
        client = Client("stzhao/LeX-Enhancer")
        combined_caption, enhanced_caption = prompt_enhance(client, image_caption, text_caption)
        print(f"enhanced caption:\n{enhanced_caption}")
    else:
        enhanced_caption = combined_caption
    image = generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale)
    return image, combined_caption, enhanced_caption
# Gradio interface
with gr.Blocks() as demo:
    # client = gr.State()
    gr.Markdown("# LeX-Enhancer & LeX-FLUX Demo")
    gr.Markdown("## Project Page: https://zhaoshitian.github.io/lexart/")
    gr.Markdown("Generate enhanced captions from simple image and text descriptions, then create images with LeX-FLUX.")

    with gr.Row():
        with gr.Column():
            image_caption = gr.Textbox(
                lines=2,
                label="Image Caption",
                placeholder="Describe the visual content of the image",
                value="A picture of a group of people gathered in front of a world map"
            )
            text_caption = gr.Textbox(
                lines=2,
                label="Text Caption",
                placeholder="Describe any text that should appear in the image",
                value="\"Communicate\" in purple, \"Execute\" in yellow"
            )
            with gr.Accordion("Advanced Settings", open=False):
                enable_enhancer = gr.Checkbox(
                    label="Enable LeX-Enhancer",
                    value=True,
                    info="When enabled, the caption will be enhanced before image generation"
                )
                seed = gr.Slider(
                    minimum=0,
                    maximum=100000,
                    value=0,
                    step=1,
                    label="Seed (0 for random)"
                )
                num_inference_steps = gr.Slider(
                    minimum=20,
                    maximum=100,
                    value=40,
                    step=1,
                    label="Number of Inference Steps"
                )
                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=7.5,
                    step=0.1,
                    label="Guidance Scale"
                )

            submit_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="Generated Image")
            combined_caption_box = gr.Textbox(
                label="Combined Caption",
                interactive=False
            )
            enhanced_caption_box = gr.Textbox(
                label="Enhanced Caption" if enable_enhancer.value else "Final Caption",
                interactive=False,
                lines=5
            )
    # Example prompts
    examples = [
        ["A modern office workspace", "\"Innovation\" in bold blue letters at the center"],
        ["A beach sunset scene", "\"Relax\" in cursive white text in the corner"],
        ["A futuristic city skyline", "\"The Future is Now\" in neon pink glowing letters"]
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_caption, text_caption],
        label="Example Inputs"
    )
    # Update the label of enhanced_caption_box based on checkbox state
    def update_caption_label(enable_enhancer):
        return gr.Textbox(label="Enhanced Caption" if enable_enhancer else "Final Caption")

    enable_enhancer.change(
        fn=update_caption_label,
        inputs=enable_enhancer,
        outputs=enhanced_caption_box
    )

    submit_btn.click(
        fn=run_pipeline,
        inputs=[image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer],
        outputs=[output_image, combined_caption_box, enhanced_caption_box]
    )
    # demo.load(set_client_for_session, None, client)

if __name__ == "__main__":
    demo.launch(debug=True)