stzhao committed (verified) · Commit a5f55f1 · Parent(s): 01f65c3

Update app.py

Files changed (1): app.py (+34 -59)
app.py CHANGED
@@ -1,5 +1,6 @@
 import os
 import gradio as gr
+from gradio_client import Client, handle_file
 import torch
 import spaces
 from diffusers import Lumina2Pipeline
@@ -10,17 +11,15 @@ if torch.cuda.is_available():
 else:
     torch_dtype = torch.float32
 
+
+def set_client_for_session(request: gr.Request):
+    x_ip_token = request.headers['x-ip-token']
+
+    # "stzhao/LeX-Enhancer" is a ZeroGPU Space; forward the visitor's quota token
+    return Client("stzhao/LeX-Enhancer", headers={"X-IP-Token": x_ip_token})
+
 # Load models
-def load_models():
-    model_name = "X-ART/LeX-Enhancer-full"
-
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        torch_dtype=torch.bfloat16,
-        device_map="auto"
-    )
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-
+def load_models():
     pipe = Lumina2Pipeline.from_pretrained(
         "X-ART/LeX-Lumina",
         torch_dtype=torch.bfloat16
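The per-session client above is how this Space now delegates caption enhancement: Spaces injects an `x-ip-token` header into each visitor's request, and forwarding it as `X-IP-Token` makes the remote ZeroGPU Space bill GPU time against the visitor's quota rather than this Space's. A minimal standalone sketch of the same remote call (the endpoint name `/generate_enhanced_caption` comes from this diff; the example captions are made up):

from gradio_client import Client

# Without the forwarded token the call still works, but it draws on the
# caller's own (possibly anonymous) ZeroGPU quota instead of the visitor's.
client = Client("stzhao/LeX-Enhancer")
combined, enhanced = client.predict(
    "A rustic wooden sign in a forest",   # image_caption (made-up example)
    "WELCOME",                            # text_caption (made-up example)
    api_name="/generate_enhanced_caption",
)
print(enhanced)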
@@ -28,52 +27,24 @@ def load_models():
     device = "cuda" if torch.cuda.is_available() else "cpu"
     pipe.to("cuda")
 
-    return model, tokenizer, pipe
-
-model, tokenizer, pipe = load_models()
-
-def truncate_caption_by_tokens(caption, max_tokens=256):
-    """Truncate the caption to fit within the max token limit"""
-    tokens = tokenizer.encode(caption)
-    if len(tokens) > max_tokens:
-        truncated_tokens = tokens[:max_tokens]
-        caption = tokenizer.decode(truncated_tokens, skip_special_tokens=True)
-        print(f"Caption was truncated from {len(tokens)} tokens to {max_tokens} tokens")
-    return caption
-
-@spaces.GPU(duration=70)
-def generate_enhanced_caption(image_caption, text_caption):
-    # model.to("cuda")
-    """Generate enhanced caption using the LeX-Enhancer model"""
-    combined_caption = f"{image_caption}, with the text on it: {text_caption}."
-    instruction = """
-    Below is the simple caption of an image with text. Please deduce the detailed description of the image based on this simple caption. Note: 1. The description should only include visual elements and should not contain any extended meanings. 2. The visual elements should be as rich as possible, such as the main objects in the image, their respective attributes, the spatial relationships between the objects, lighting and shadows, color style, any text in the image and its style, etc. 3. The output description should be a single paragraph and should not be structured. 4. The description should avoid certain situations, such as pure white or black backgrounds, blurry text, excessive rendering of text, or harsh visual styles. 5. The detailed caption should be human readable and fluent. 6. Avoid using vague expressions such as "may be" or "might be"; the generated caption must be in a definitive, narrative tone. 7. Do not use negative sentence structures, such as "there is nothing in the image," etc. The entire caption should directly describe the content of the image. 8. The entire output should be limited to 200 words.
-    """
-    messages = [
-        {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
-        {"role": "user", "content": instruction + "\nSimple Caption:\n" + combined_caption}
-    ]
-    text = tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True
-    )
-    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+    return pipe
 
-    generated_ids = model.generate(
-        **model_inputs,
-        max_new_tokens=1024
-    )
-    generated_ids = [
-        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-    ]
-
-    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    enhanced_caption = response.split("</think>", -1)[-1].strip(" ").strip("\n")
-    model.to("cpu")
-    torch.cuda.empty_cache()
-
+def prompt_enhance(client, image_caption, text_caption):
+    combined_caption, enhanced_caption = client.predict(image_caption, text_caption, api_name="/generate_enhanced_caption")
     return combined_caption, enhanced_caption
+
+
+pipe = load_models()
+
+# def truncate_caption_by_tokens(caption, max_tokens=256):
+#     """Truncate the caption to fit within the max token limit"""
+#     tokens = tokenizer.encode(caption)
+#     if len(tokens) > max_tokens:
+#         truncated_tokens = tokens[:max_tokens]
+#         caption = tokenizer.decode(truncated_tokens, skip_special_tokens=True)
+#         print(f"Caption was truncated from {len(tokens)} tokens to {max_tokens} tokens")
+#     return caption
+
 
 @spaces.GPU(duration=60)
 def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):
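This hunk is the core of the commit: the locally loaded Qwen-based LeX-Enhancer (and its 70-second `@spaces.GPU` reservation) is replaced by `prompt_enhance`, a thin wrapper over a network call. Unlike the deleted in-process `generate_enhanced_caption`, it can now fail when the remote Space is cold, rate-limited, or down. A defensive variant (hypothetical; the fallback behaviour is not part of this commit) would degrade to the unenhanced caption:

def prompt_enhance_safe(client, image_caption, text_caption):
    # Same combined-caption format the rest of app.py uses
    combined_caption = f"{image_caption}, with the text on it: {text_caption}."
    try:
        combined_caption, enhanced_caption = client.predict(
            image_caption, text_caption,
            api_name="/generate_enhanced_caption",
        )
    except Exception as err:
        print(f"LeX-Enhancer call failed ({err}); falling back to the plain caption")
        enhanced_caption = combined_caption
    return combined_caption, enhanced_caption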
@@ -81,7 +52,7 @@ def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):
     pipe.enable_model_cpu_offload()
     """Generate image using LeX-Lumina"""
     # Truncate the caption if it's too long
-    enhanced_caption = truncate_caption_by_tokens(enhanced_caption, max_tokens=256)
+    # enhanced_caption = truncate_caption_by_tokens(enhanced_caption, max_tokens=256)
 
     print(f"enhanced caption:\n{enhanced_caption}")
 
@@ -107,13 +78,14 @@ def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):
 
     return image
 
-@spaces.GPU(duration=130)
-def run_pipeline(image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer):
+# @spaces.GPU(duration=130)
+def run_pipeline(image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer, client):
     """Run the complete pipeline from captions to final image"""
     combined_caption = f"{image_caption}, with the text on it: {text_caption}."
 
     if enable_enhancer:
-        combined_caption, enhanced_caption = generate_enhanced_caption(image_caption, text_caption)
+        # combined_caption, enhanced_caption = generate_enhanced_caption(image_caption, text_caption)
+        combined_caption, enhanced_caption = prompt_enhance(client, image_caption, text_caption)
     else:
         enhanced_caption = combined_caption
 
@@ -123,6 +95,7 @@ def run_pipeline(image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer):
 
 # Gradio interface
 with gr.Blocks() as demo:
+    client = gr.State()
     gr.Markdown("# LeX-Enhancer & LeX-Lumina Demo")
     gr.Markdown("## Project Page: https://zhaoshitian.github.io/lexart/")
     gr.Markdown("Generate enhanced captions from simple image and text descriptions, then create images with LeX-Lumina")
@@ -208,9 +181,11 @@ with gr.Blocks() as demo:
 
     submit_btn.click(
         fn=run_pipeline,
-        inputs=[image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer],
+        inputs=[image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer, client],
         outputs=[output_image, combined_caption_box, enhanced_caption_box]
     )
 
+    demo.load(set_client_for_session, None, client)
+
 if __name__ == "__main__":
     demo.launch(debug=True)
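Taken together, the commit leaves only the 60-second image-generation step on this Space's GPU and wires the remote enhancer in with the standard Gradio session pattern: `gr.State()` holds one `Client` per browser session, `demo.load` populates it on page load, and event handlers receive it as an extra input. A stripped-down sketch of just that pattern (UI components omitted; the `.get` fallback for local runs is an assumption, not part of this commit):

import gradio as gr
from gradio_client import Client

def set_client_for_session(request: gr.Request):
    # ZeroGPU injects this header per visitor; it is absent when run locally
    token = request.headers.get("x-ip-token")
    return Client("stzhao/LeX-Enhancer",
                  headers={"X-IP-Token": token} if token else None)

with gr.Blocks() as demo:
    client = gr.State()  # one value per browser session
    # ... components; handlers list `client` among their inputs ...
    demo.load(set_client_for_session, None, client)  # runs on each page load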
 