Update app.py
--- a/app.py
+++ b/app.py
@@ -16,14 +16,14 @@ def load_models():
 
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
-        torch_dtype=
+        torch_dtype=torch.bfloat16,
         # device_map="auto"
     )
     tokenizer = AutoTokenizer.from_pretrained(model_name)
 
     pipe = Lumina2Pipeline.from_pretrained(
         "X-ART/LeX-Lumina",
-        torch_dtype=
+        torch_dtype=torch.bfloat16
     )
     device = "cuda" if torch.cuda.is_available() else "cpu"
     # pipe.to("cuda")
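Both the caption model and the LeX-Lumina pipeline are now loaded in bfloat16. As a side note rather than part of the commit: bfloat16 is only accelerated on GPUs that support it, so a small guard like the hypothetical helper below could fall back to float16 or float32, and `from_pretrained(..., torch_dtype=pick_dtype())` would slot into the lines above unchanged.

import torch

# Hypothetical helper (not in the commit): choose bfloat16 only when the
# current CUDA device supports it, otherwise fall back to float16 on GPU
# or float32 on CPU.
def pick_dtype():
    if torch.cuda.is_available():
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return torch.float32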
@@ -107,9 +107,15 @@ def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):
     return image
 
 @spaces.GPU(duration=100)
-def run_pipeline(image_caption, text_caption, seed, num_inference_steps, guidance_scale):
+def run_pipeline(image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer):
     """Run the complete pipeline from captions to final image"""
-    combined_caption
+    combined_caption = f"{image_caption}, with the text on it: {text_caption}."
+
+    if enable_enhancer:
+        combined_caption, enhanced_caption = generate_enhanced_caption(image_caption, text_caption)
+    else:
+        enhanced_caption = combined_caption
+
     image = generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale)
 
     return {
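run_pipeline now takes an enable_enhancer flag: when it is off, the plain combined caption is reused as the enhanced caption and the LeX-Enhancer step is skipped entirely. A minimal standalone sketch of that branching, where the enhance_fn stub only stands in for the app's own generate_enhanced_caption:

def build_captions(image_caption, text_caption, enable_enhancer, enhance_fn=None):
    # Plain combination used when the enhancer is disabled.
    combined = f"{image_caption}, with the text on it: {text_caption}."
    if enable_enhancer and enhance_fn is not None:
        # enhance_fn mirrors generate_enhanced_caption and returns
        # (combined_caption, enhanced_caption).
        return enhance_fn(image_caption, text_caption)
    # Without the enhancer both captions are identical.
    return combined, combined

# Example with the enhancer disabled:
# build_captions("A movie poster", "LeX-Lumina", False)
# -> ("A movie poster, with the text on it: LeX-Lumina.",
#     "A movie poster, with the text on it: LeX-Lumina.")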
@@ -139,6 +145,11 @@ with gr.Blocks() as demo:
             )
 
             with gr.Accordion("Advanced Settings", open=False):
+                enable_enhancer = gr.Checkbox(
+                    label="Enable LeX-Enhancer",
+                    value=False,
+                    info="When enabled, the caption will be enhanced before image generation"
+                )
                 seed = gr.Slider(
                     minimum=0,
                     maximum=100000,
@@ -170,7 +181,7 @@ with gr.Blocks() as demo:
                 interactive=False
             )
             enhanced_caption_box = gr.Textbox(
-                label="Enhanced Caption",
+                label="Enhanced Caption" if enable_enhancer.value else "Final Caption",
                 interactive=False,
                 lines=5
             )
@@ -187,9 +198,19 @@ with gr.Blocks() as demo:
         label="Example Inputs"
     )
 
+    # Update the label of enhanced_caption_box based on checkbox state
+    def update_caption_label(enable_enhancer):
+        return gr.Textbox.update(label="Enhanced Caption" if enable_enhancer else "Final Caption")
+
+    enable_enhancer.change(
+        fn=update_caption_label,
+        inputs=enable_enhancer,
+        outputs=enhanced_caption_box
+    )
+
     submit_btn.click(
         fn=run_pipeline,
-        inputs=[image_caption, text_caption, seed, num_inference_steps, guidance_scale],
+        inputs=[image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer],
         outputs=[output_image, combined_caption_box, enhanced_caption_box]
     )
 
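Two notes on the label handling, offered as review commentary rather than as part of the commit: the ternary in the Textbox constructor reads enable_enhancer.value only once at build time, so the .change handler above is what keeps the label in sync afterwards; and gr.Textbox.update is the Gradio 3.x API, removed in Gradio 4.x in favour of gr.update. If the Space runs Gradio 4, a drop-in variant of the callback (same wiring, only the return value changes) could look like:

import gradio as gr

# Gradio 4.x variant of update_caption_label: components no longer expose
# an .update() classmethod, so a generic gr.update(...) is returned instead.
def update_caption_label(enable_enhancer):
    return gr.update(label="Enhanced Caption" if enable_enhancer else "Final Caption")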