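# Gradio demo: LeX-Enhancer (a Qwen-based caption enhancer) expands a simple
# image + text caption into a detailed prompt, which LeX-Lumina then renders
# into an image containing the requested text.
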
import os
import gradio as gr
import torch
import spaces
from diffusers import Lumina2Pipeline
from transformers import AutoModelForCausalLM, AutoTokenizer
# Optionally pin the demo to a single GPU:
# os.environ['CUDA_VISIBLE_DEVICES'] = "0"

# Prefer bfloat16 on GPU; fall back to float32 on CPU
if torch.cuda.is_available():
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float32

# Load models
def load_models():
    model_name = "X-ART/LeX-Enhancer-full"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    pipe = Lumina2Pipeline.from_pretrained(
        "X-ART/LeX-Lumina",
        torch_dtype=torch.bfloat16
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe.to(device, torch_dtype)
    return model, tokenizer, pipe
model, tokenizer, pipe = load_models()
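
# On Hugging Face Spaces, @spaces.GPU requests ZeroGPU hardware for the call,
# releasing it after roughly `duration` seconds; outside Spaces the decorator
# has no effect.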
@spaces.GPU(duration=100)
def generate_enhanced_caption(image_caption, text_caption):
    """Generate enhanced caption using the LeX-Enhancer model"""
    combined_caption = f"{image_caption}, with the text on it: {text_caption}."
    instruction = """
Below is the simple caption of an image with text. Please deduce the detailed description of the image based on this simple caption. Note: 1. The description should only include visual elements and should not contain any extended meanings. 2. The visual elements should be as rich as possible, such as the main objects in the image, their respective attributes, the spatial relationships between the objects, lighting and shadows, color style, any text in the image and its style, etc. 3. The output description should be a single paragraph and should not be structured. 4. The description should avoid certain situations, such as pure white or black backgrounds, blurry text, excessive rendering of text, or harsh visual styles. 5. The detailed caption should be human readable and fluent. 6. Avoid using vague expressions such as "may be" or "might be"; the generated caption must be in a definitive, narrative tone. 7. Do not use negative sentence structures, such as "there is nothing in the image," etc. The entire caption should directly describe the content of the image. 8. The entire output should be limited to 200 words.
"""
    messages = [
        {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
        {"role": "user", "content": instruction + "\nSimple Caption:\n" + combined_caption}
    ]
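    # Render the conversation with the tokenizer's chat template, appending the
    # generation prompt so the model responds as the assistant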
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=1024
    )
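    # Slice off the prompt tokens so only the newly generated tokens are decoded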
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
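    # The enhancer may emit intermediate reasoning ended by a "</think>" marker;
    # keep only the final caption after it (if the marker is absent, split()
    # returns the whole response unchanged)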
    enhanced_caption = response.split("</think>")[-1].strip()
    return combined_caption, enhanced_caption

@spaces.GPU(duration=100)
def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):
    """Generate image using LeX-Lumina"""
    generator = torch.Generator("cpu").manual_seed(seed) if seed != 0 else None
    image = pipe(
        enhanced_caption,
        height=1024,
        width=1024,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        cfg_trunc_ratio=1,
        cfg_normalization=True,
        max_sequence_length=256,
        generator=generator,
        system_prompt="You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts.",
    ).images[0]
    return image

def run_pipeline(image_caption, text_caption, seed, num_inference_steps, guidance_scale):
    """Run the complete pipeline from captions to final image"""
    combined_caption, enhanced_caption = generate_enhanced_caption(image_caption, text_caption)
    image = generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale)
    # Return values in the same order as the Gradio outputs list
    return image, combined_caption, enhanced_caption

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# LeX-Enhancer & LeX-Lumina Demo")
    gr.Markdown("Generate enhanced captions from simple image and text descriptions, then create images with LeX-Lumina")

    with gr.Row():
        with gr.Column():
            image_caption = gr.Textbox(
                lines=2,
                label="Image Caption",
                placeholder="Describe the visual content of the image",
                value="A picture of a group of people gathered in front of a world map"
            )
            text_caption = gr.Textbox(
                lines=2,
                label="Text Caption",
                placeholder="Describe any text that should appear in the image",
                value="\"Communicate\" in purple, \"Execute\" in yellow"
            )

            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(
                    minimum=0,
                    maximum=100000,
                    value=0,
                    step=1,
                    label="Seed (0 for random)"
                )
                num_inference_steps = gr.Slider(
                    minimum=20,
                    maximum=100,
                    value=30,
                    step=1,
                    label="Number of Inference Steps"
                )
                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=4.0,
                    step=0.1,
                    label="Guidance Scale"
                )

            submit_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Generated Image")
            combined_caption_box = gr.Textbox(
                label="Combined Caption",
                interactive=False
            )
            enhanced_caption_box = gr.Textbox(
                label="Enhanced Caption",
                interactive=False,
                lines=5
            )

    # Example prompts
    examples = [
        ["A modern office workspace", "\"Innovation\" in bold blue letters at the center"],
        ["A beach sunset scene", "\"Relax\" in cursive white text in the corner"],
        ["A futuristic city skyline", "\"The Future is Now\" in neon pink glowing letters"]
    ]

    gr.Examples(
        examples=examples,
        inputs=[image_caption, text_caption],
        label="Example Inputs"
    )
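
    # Wire the Generate button to the two-stage pipeline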
    submit_btn.click(
        fn=run_pipeline,
        inputs=[image_caption, text_caption, seed, num_inference_steps, guidance_scale],
        outputs=[output_image, combined_caption_box, enhanced_caption_box]
    )

if __name__ == "__main__":
    demo.launch(debug=True)