stzhao commited on
Commit
9bfc50b
·
verified ·
1 Parent(s): e435d1a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -6
app.py CHANGED
@@ -16,14 +16,14 @@ def load_models():
16
 
17
  model = AutoModelForCausalLM.from_pretrained(
18
  model_name,
19
- torch_dtype=torch.bfloat16,
20
  # device_map="auto"
21
  )
22
  tokenizer = AutoTokenizer.from_pretrained(model_name)
23
 
24
  pipe = Lumina2Pipeline.from_pretrained(
25
  "X-ART/LeX-Lumina",
26
- torch_dtype=torch.bfloat16
27
  )
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
  # pipe.to("cuda")
@@ -107,9 +107,15 @@ def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):
107
  return image
108
 
109
  @spaces.GPU(duration=100)
110
- def run_pipeline(image_caption, text_caption, seed, num_inference_steps, guidance_scale):
111
  """Run the complete pipeline from captions to final image"""
112
- combined_caption, enhanced_caption = generate_enhanced_caption(image_caption, text_caption)
 
 
 
 
 
 
113
  image = generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale)
114
 
115
  return {
@@ -139,6 +145,11 @@ with gr.Blocks() as demo:
139
  )
140
 
141
  with gr.Accordion("Advanced Settings", open=False):
 
 
 
 
 
142
  seed = gr.Slider(
143
  minimum=0,
144
  maximum=100000,
@@ -170,7 +181,7 @@ with gr.Blocks() as demo:
170
  interactive=False
171
  )
172
  enhanced_caption_box = gr.Textbox(
173
- label="Enhanced Caption",
174
  interactive=False,
175
  lines=5
176
  )
@@ -187,9 +198,19 @@ with gr.Blocks() as demo:
187
  label="Example Inputs"
188
  )
189
 
 
 
 
 
 
 
 
 
 
 
190
  submit_btn.click(
191
  fn=run_pipeline,
192
- inputs=[image_caption, text_caption, seed, num_inference_steps, guidance_scale],
193
  outputs=[output_image, combined_caption_box, enhanced_caption_box]
194
  )
195
 
 
16
 
17
  model = AutoModelForCausalLM.from_pretrained(
18
  model_name,
19
+ torch_dtype=torch_bfloat16,
20
  # device_map="auto"
21
  )
22
  tokenizer = AutoTokenizer.from_pretrained(model_name)
23
 
24
  pipe = Lumina2Pipeline.from_pretrained(
25
  "X-ART/LeX-Lumina",
26
+ torch_dtype=torch_bfloat16
27
  )
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
  # pipe.to("cuda")
 
107
  return image
108
 
109
  @spaces.GPU(duration=100)
110
+ def run_pipeline(image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer):
111
  """Run the complete pipeline from captions to final image"""
112
+ combined_caption = f"{image_caption}, with the text on it: {text_caption}."
113
+
114
+ if enable_enhancer:
115
+ combined_caption, enhanced_caption = generate_enhanced_caption(image_caption, text_caption)
116
+ else:
117
+ enhanced_caption = combined_caption
118
+
119
  image = generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale)
120
 
121
  return {
 
145
  )
146
 
147
  with gr.Accordion("Advanced Settings", open=False):
148
+ enable_enhancer = gr.Checkbox(
149
+ label="Enable LeX-Enhancer",
150
+ value=False,
151
+ info="When enabled, the caption will be enhanced before image generation"
152
+ )
153
  seed = gr.Slider(
154
  minimum=0,
155
  maximum=100000,
 
181
  interactive=False
182
  )
183
  enhanced_caption_box = gr.Textbox(
184
+ label="Enhanced Caption" if enable_enhancer.value else "Final Caption",
185
  interactive=False,
186
  lines=5
187
  )
 
198
  label="Example Inputs"
199
  )
200
 
201
+ # Update the label of enhanced_caption_box based on checkbox state
202
+ def update_caption_label(enable_enhancer):
203
+ return gr.Textbox.update(label="Enhanced Caption" if enable_enhancer else "Final Caption")
204
+
205
+ enable_enhancer.change(
206
+ fn=update_caption_label,
207
+ inputs=enable_enhancer,
208
+ outputs=enhanced_caption_box
209
+ )
210
+
211
  submit_btn.click(
212
  fn=run_pipeline,
213
+ inputs=[image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer],
214
  outputs=[output_image, combined_caption_box, enhanced_caption_box]
215
  )
216