Update app.py
app.py CHANGED
@@ -1,5 +1,6 @@
 import os
 import gradio as gr
+from gradio_client import Client, handle_file
 import torch
 import spaces
 from diffusers import Lumina2Pipeline
@@ -10,17 +11,15 @@ if torch.cuda.is_available():
 else:
     torch_dtype = torch.float32

+
+def set_client_for_session(request: gr.Request):
+    x_ip_token = request.headers['x-ip-token']
+
+    # The "gradio/text-to-image" space is a ZeroGPU space
+    return Client("stzhao/LeX-Enhancer", headers={"X-IP-Token": x_ip_token})
+
 # Load models
-def load_models():
-    model_name = "X-ART/LeX-Enhancer-full"
-
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        torch_dtype=torch.bfloat16,
-        device_map="auto"
-    )
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-
+def load_models():
     pipe = Lumina2Pipeline.from_pretrained(
         "X-ART/LeX-Lumina",
         torch_dtype=torch.bfloat16
@@ -28,52 +27,24 @@ def load_models():
     device = "cuda" if torch.cuda.is_available() else "cpu"
     pipe.to("cuda")

-    return
-
-model, tokenizer, pipe = load_models()
-
-def truncate_caption_by_tokens(caption, max_tokens=256):
-    """Truncate the caption to fit within the max token limit"""
-    tokens = tokenizer.encode(caption)
-    if len(tokens) > max_tokens:
-        truncated_tokens = tokens[:max_tokens]
-        caption = tokenizer.decode(truncated_tokens, skip_special_tokens=True)
-        print(f"Caption was truncated from {len(tokens)} tokens to {max_tokens} tokens")
-    return caption
-
-@spaces.GPU(duration=70)
-def generate_enhanced_caption(image_caption, text_caption):
-    # model.to("cuda")
-    """Generate enhanced caption using the LeX-Enhancer model"""
-    combined_caption = f"{image_caption}, with the text on it: {text_caption}."
-    instruction = """
-    Below is the simple caption of an image with text. Please deduce the detailed description of the image based on this simple caption. Note: 1. The description should only include visual elements and should not contain any extended meanings. 2. The visual elements should be as rich as possible, such as the main objects in the image, their respective attributes, the spatial relationships between the objects, lighting and shadows, color style, any text in the image and its style, etc. 3. The output description should be a single paragraph and should not be structured. 4. The description should avoid certain situations, such as pure white or black backgrounds, blurry text, excessive rendering of text, or harsh visual styles. 5. The detailed caption should be human readable and fluent. 6. Avoid using vague expressions such as "may be" or "might be"; the generated caption must be in a definitive, narrative tone. 7. Do not use negative sentence structures, such as "there is nothing in the image," etc. The entire caption should directly describe the content of the image. 8. The entire output should be limited to 200 words.
-    """
-    messages = [
-        {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
-        {"role": "user", "content": instruction + "\nSimple Caption:\n" + combined_caption}
-    ]
-    text = tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True
-    )
-    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+    return pipe

-    generated_ids = model.generate(
-        **model_inputs,
-        max_new_tokens=1024
-    )
-    generated_ids = [
-        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-    ]
-
-    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    enhanced_caption = response.split("</think>", -1)[-1].strip(" ").strip("\n")
-    model.to("cpu")
-    torch.cuda.empty_cache()
-
+def prompt_enhance(client, image_caption, text_caption):
+    combined_caption, enhanced_caption = client.predict(image_caption, text_caption, api_name="/generate_enhanced_caption")
     return combined_caption, enhanced_caption
+
+
+pipe = load_models()
+
+# def truncate_caption_by_tokens(caption, max_tokens=256):
+#     """Truncate the caption to fit within the max token limit"""
+#     tokens = tokenizer.encode(caption)
+#     if len(tokens) > max_tokens:
+#         truncated_tokens = tokens[:max_tokens]
+#         caption = tokenizer.decode(truncated_tokens, skip_special_tokens=True)
+#         print(f"Caption was truncated from {len(tokens)} tokens to {max_tokens} tokens")
+#     return caption
+

 @spaces.GPU(duration=60)
 def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):
@@ -81,7 +52,7 @@ def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):
     pipe.enable_model_cpu_offload()
     """Generate image using LeX-Lumina"""
     # Truncate the caption if it's too long
-    enhanced_caption = truncate_caption_by_tokens(enhanced_caption, max_tokens=256)
+    # enhanced_caption = truncate_caption_by_tokens(enhanced_caption, max_tokens=256)

     print(f"enhanced caption:\n{enhanced_caption}")

@@ -107,13 +78,14 @@ def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):

     return image

-@spaces.GPU(duration=130)
-def run_pipeline(image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer):
+# @spaces.GPU(duration=130)
+def run_pipeline(image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer, client):
     """Run the complete pipeline from captions to final image"""
     combined_caption = f"{image_caption}, with the text on it: {text_caption}."

     if enable_enhancer:
-        combined_caption, enhanced_caption = generate_enhanced_caption(image_caption, text_caption)
+        # combined_caption, enhanced_caption = generate_enhanced_caption(image_caption, text_caption)
+        combined_caption, enhanced_caption = prompt_enhance(client, image_caption, text_caption)
     else:
         enhanced_caption = combined_caption

@@ -123,6 +95,7 @@ def run_pipeline(image_caption, text_caption, seed, num_inference_steps, guidanc

 # Gradio interface
 with gr.Blocks() as demo:
+    client = gr.State()
     gr.Markdown("# LeX-Enhancer & LeX-Lumina Demo")
     gr.Markdown("## Project Page: https://zhaoshitian.github.io/lexart/")
     gr.Markdown("Generate enhanced captions from simple image and text descriptions, then create images with LeX-Lumina")
@@ -208,9 +181,11 @@ with gr.Blocks() as demo:

     submit_btn.click(
         fn=run_pipeline,
-        inputs=[image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer],
+        inputs=[image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer, client],
         outputs=[output_image, combined_caption_box, enhanced_caption_box]
     )

+    demo.load(set_client_for_session, None, client)
+
 if __name__ == "__main__":
     demo.launch(debug=True)
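
The remote enhancement path introduced by this change can also be exercised on its own with gradio_client. The snippet below is a minimal sketch, not part of the commit: it assumes the stzhao/LeX-Enhancer Space exposes the /generate_enhanced_caption endpoint with the same two-string signature that prompt_enhance() calls, and it omits the X-IP-Token header because that header is only available for requests routed through a ZeroGPU Space (where set_client_for_session() forwards it). The example caption values are placeholders.

from gradio_client import Client

# Hypothetical standalone call to the remote enhancer endpoint used by prompt_enhance().
# No X-IP-Token header is sent here; inside the Space, set_client_for_session() forwards
# the per-session token taken from request.headers['x-ip-token'].
client = Client("stzhao/LeX-Enhancer")

combined_caption, enhanced_caption = client.predict(
    "A vintage tin sign on a brick wall",  # image_caption (example value)
    "OPEN 24 HOURS",                       # text_caption (example value)
    api_name="/generate_enhanced_caption",
)
print(combined_caption)
print(enhanced_caption)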