File size: 6,751 Bytes
5a639eb
57c8777
a5f55f1
57c8777
47f6b10
6624a4b
5a639eb
 
9fd6eb6
 
 
 
 
a5f55f1
9e8b0d0
 
a5f55f1
9e8b0d0
 
 
a5f55f1
5a639eb
a5f55f1
6624a4b
 
8bd8972
5a639eb
9fd6eb6
a43380e
5a639eb
a5f55f1
57c8777
a5f55f1
 
5a639eb
a5f55f1
 
 
 
 
 
 
 
 
 
 
 
 
5a639eb
cbf45f4
c29c340
a43380e
17764e8
80dc987
2092a85
a5f55f1
2092a85
5a639eb
6624a4b
57c8777
5a639eb
 
 
6624a4b
 
 
 
 
57c8777
e435d1a
 
 
9195195
 
5a639eb
 
 
a5f55f1
9e8b0d0
5a639eb
9bfc50b
 
 
a5f55f1
9e8b0d0
a5f55f1
6066007
9bfc50b
 
 
5a639eb
 
d7941ea
5a639eb
 
 
9e8b0d0
80dc987
6bcb844
6066007
5a639eb
 
 
 
 
 
 
 
57c8777
5a639eb
 
 
 
 
57c8777
5a639eb
 
9bfc50b
 
1032044
9bfc50b
 
5a639eb
 
 
 
 
 
57c8777
5a639eb
 
 
7444c68
5a639eb
 
57c8777
 
5a639eb
57c8777
7444c68
57c8777
5a639eb
57c8777
5a639eb
 
 
 
 
 
 
 
 
 
9bfc50b
5a639eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bfc50b
 
095edc3
9bfc50b
 
 
 
 
 
 
5a639eb
 
9e8b0d0
5a639eb
57c8777
 
9e8b0d0
a5f55f1
57c8777
cbf45f4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import os
import gradio as gr
from gradio_client import Client, handle_file
import torch
import spaces
from diffusers import FluxPipeline
from transformers import AutoModelForCausalLM, AutoTokenizer

# Preferred floating-point precision: bfloat16 when a CUDA device is visible,
# full float32 otherwise.
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32


# def set_client_for_session(request: gr.Request):
#     x_ip_token = request.headers['x-ip-token']

#     # The "gradio/text-to-image" space is a ZeroGPU space
#     # return Client("stzhao/LeX-Enhancer", headers={"X-IP-Token": x_ip_token})
#     return Client("stzhao/LeX-Enhancer")

# Load models
def load_models():
    """Load the LeX-FLUX pipeline and place it on the available device.

    Returns:
        A ``FluxPipeline`` loaded from the ``X-ART/LeX-FLUX`` checkpoint with
        bfloat16 weights, moved to ``"cuda"`` when CUDA is available and to
        ``"cpu"`` otherwise.
    """
    pipe = FluxPipeline.from_pretrained(
        "X-ART/LeX-FLUX",
        torch_dtype=torch.bfloat16,
    )
    # Use the detected device rather than hard-coding "cuda", which raises on
    # hosts without a CUDA device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe.to(device)
    return pipe

def prompt_enhance(client, image_caption, text_caption):
    """Ask the remote enhancer Space to expand the two captions.

    Args:
        client: A gradio_client ``Client`` (or anything with a compatible
            ``predict`` method) connected to the enhancer Space.
        image_caption: Description of the visual content.
        text_caption: Description of the text to render in the image.

    Returns:
        A ``(combined_caption, enhanced_caption)`` pair as produced by the
        ``/generate_enhanced_caption`` endpoint.
    """
    response = client.predict(
        image_caption,
        text_caption,
        api_name="/generate_enhanced_caption",
    )
    combined, enhanced = response
    return combined, enhanced
    

# Module-level pipeline, loaded once at import time and shared by generate_image().
pipe = load_models()

# def truncate_caption_by_tokens(caption, max_tokens=256):
#     """Truncate the caption to fit within the max token limit"""
#     tokens = tokenizer.encode(caption)
#     if len(tokens) > max_tokens:
#         truncated_tokens = tokens[:max_tokens]
#         caption = tokenizer.decode(truncated_tokens, skip_special_tokens=True)
#         print(f"Caption was truncated from {len(tokens)} tokens to {max_tokens} tokens")
#     return caption


@spaces.GPU(duration=60)
def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):
    """Generate a 1024x1024 image from ``enhanced_caption`` with LeX-FLUX.

    Args:
        enhanced_caption: Prompt passed to the FLUX pipeline.
        seed: RNG seed. ``0`` means "random": no explicit generator is passed,
            matching the UI's "Seed (0 for random)" slider.
        num_inference_steps: Number of denoising steps for the pipeline.
        guidance_scale: Guidance scale for the pipeline.

    Returns:
        The first PIL image produced by the pipeline.
    """
    # Installs diffusers' model CPU-offload hooks before running the pipeline.
    pipe.enable_model_cpu_offload()

    # Slider values may arrive as floats; manual_seed and the step count need ints.
    seed = int(seed)
    generator = torch.Generator("cpu").manual_seed(seed) if seed != 0 else None

    result = pipe(
        enhanced_caption,
        height=1024,
        width=1024,
        guidance_scale=float(guidance_scale),
        output_type="pil",
        num_inference_steps=int(num_inference_steps),
        max_sequence_length=512,
        generator=generator,
    )
    return result.images[0]

# @spaces.GPU(duration=130)
def run_pipeline(image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer):
    """Turn the two user captions into an image, optionally enhancing them first.

    Without the enhancer, the prompt is the captions joined by a fixed template.
    With it, the remote "stzhao/LeX-Enhancer" Space supplies both the combined
    caption and the enhanced prompt.

    Returns:
        ``(image, combined_caption, final_caption)``, where ``final_caption``
        is the prompt that was actually sent to the image generator.
    """
    default_caption = f"{image_caption}, with the text on it: {text_caption}."
    combined_caption, final_caption = default_caption, default_caption

    if enable_enhancer:
        combined_caption, final_caption = prompt_enhance(
            Client("stzhao/LeX-Enhancer"), image_caption, text_caption
        )
        print(f"enhanced caption:\n{final_caption}")

    generated = generate_image(final_caption, seed, num_inference_steps, guidance_scale)
    return generated, combined_caption, final_caption

# Gradio interface
with gr.Blocks() as demo:
    # client = gr.State()
    gr.Markdown("# LeX-Enhancer & LeX-FLUX Demo")
    gr.Markdown("## Project Page: https://zhaoshitian.github.io/lexart/")
    gr.Markdown("Generate enhanced captions from simple image and text descriptions, then create images with LeX-FLUX")

    with gr.Row():
        # Left column: caption inputs and generation settings.
        with gr.Column():
            image_caption = gr.Textbox(
                lines=2,
                label="Image Caption",
                placeholder="Describe the visual content of the image",
                value="A picture of a group of people gathered in front of a world map"
            )
            text_caption = gr.Textbox(
                lines=2,
                label="Text Caption",
                placeholder="Describe any text that should appear in the image",
                value="\"Communicate\" in purple, \"Execute\" in yellow"
            )

            with gr.Accordion("Advanced Settings", open=False):
                enable_enhancer = gr.Checkbox(
                    label="Enable LeX-Enhancer",
                    value=True,
                    info="When enabled, the caption will be enhanced before image generation"
                )
                seed = gr.Slider(
                    minimum=0,
                    maximum=100000,
                    value=0,
                    step=1,
                    label="Seed (0 for random)"
                )
                # Default matches the step count the app generated with (28).
                num_inference_steps = gr.Slider(
                    minimum=20,
                    maximum=100,
                    value=28,
                    step=1,
                    label="Number of Inference Steps"
                )
                # Default matches the guidance scale the app generated with (3.5).
                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=3.5,
                    step=0.1,
                    label="Guidance Scale"
                )

            submit_btn = gr.Button("Generate", variant="primary")

        # Right column: generated image and the captions that produced it.
        with gr.Column():
            output_image = gr.Image(label="Generated Image")
            combined_caption_box = gr.Textbox(
                label="Combined Caption",
                interactive=False
            )
            enhanced_caption_box = gr.Textbox(
                label="Enhanced Caption" if enable_enhancer.value else "Final Caption",
                interactive=False,
                lines=5
            )

    # Example prompts
    examples = [
        ["A modern office workspace", "\"Innovation\" in bold blue letters at the center"],
        ["A beach sunset scene", "\"Relax\" in cursive white text in the corner"],
        ["A futuristic city skyline", "\"The Future is Now\" in neon pink glowing letters"]
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_caption, text_caption],
        label="Example Inputs"
    )

    # Update the label of enhanced_caption_box based on checkbox state
    def update_caption_label(enable_enhancer):
        """Return a Textbox update relabelling the caption box for the enhancer state."""
        return gr.Textbox(label="Enhanced Caption" if enable_enhancer else "Final Caption")

    enable_enhancer.change(
        fn=update_caption_label,
        inputs=enable_enhancer,
        outputs=enhanced_caption_box
    )

    # Input order must match run_pipeline's positional parameters.
    submit_btn.click(
        fn=run_pipeline,
        inputs=[image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer],
        outputs=[output_image, combined_caption_box, enhanced_caption_box]
    )

    # demo.load(set_client_for_session, None, client)

# Launch the Gradio app only when this file is run directly, not when it is imported.
if __name__ == "__main__":
    # NOTE(review): debug=True is Gradio's blocking/debug launch mode; confirm it is wanted outside development.
    demo.launch(debug=True)