Spaces:
Runtime error
Runtime error
lionelgarnier
commited on
Commit
·
412c4ad
1
Parent(s):
8538434
add example selection and processing pipeline functions
Browse files
app.py
CHANGED
@@ -22,7 +22,7 @@ _image_gen_pipeline = None
|
|
22 |
@spaces.GPU()
|
23 |
def get_image_gen_pipeline():
|
24 |
global _image_gen_pipeline
|
25 |
-
if _image_gen_pipeline is None:
|
26 |
try:
|
27 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
28 |
dtype = torch.bfloat16
|
@@ -42,7 +42,7 @@ def get_image_gen_pipeline():
|
|
42 |
@spaces.GPU()
|
43 |
def get_text_gen_pipeline():
|
44 |
global _text_gen_pipeline
|
45 |
-
if _text_gen_pipeline is None:
|
46 |
try:
|
47 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
48 |
tokenizer = AutoTokenizer.from_pretrained(
|
@@ -179,6 +179,36 @@ def preload_models():
|
|
179 |
print(status)
|
180 |
return success
|
181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
def create_interface():
|
183 |
# Preload models if needed
|
184 |
if PRELOAD_MODELS:
|
@@ -246,16 +276,28 @@ def create_interface():
|
|
246 |
minimum=1,
|
247 |
maximum=50,
|
248 |
step=1,
|
249 |
-
value=
|
250 |
)
|
251 |
|
252 |
-
# Examples section
|
253 |
gr.Examples(
|
254 |
examples=examples,
|
255 |
-
fn=
|
256 |
-
inputs=[
|
257 |
-
|
258 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
259 |
)
|
260 |
|
261 |
# Event handlers
|
|
|
22 |
@spaces.GPU()
|
23 |
def get_image_gen_pipeline():
|
24 |
global _image_gen_pipeline
|
25 |
+
if (_image_gen_pipeline is None):
|
26 |
try:
|
27 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
28 |
dtype = torch.bfloat16
|
|
|
42 |
@spaces.GPU()
|
43 |
def get_text_gen_pipeline():
|
44 |
global _text_gen_pipeline
|
45 |
+
if (_text_gen_pipeline is None):
|
46 |
try:
|
47 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
48 |
tokenizer = AutoTokenizer.from_pretrained(
|
|
|
179 |
print(status)
|
180 |
return success
|
181 |
|
182 |
+
# Add a new function to handle example selection
|
183 |
+
def handle_example_click(example_prompt):
|
184 |
+
# Immediately return the example prompt to update the UI
|
185 |
+
return example_prompt, "Example selected - click 'Refine prompt with Mistral' to process"
|
186 |
+
|
187 |
+
# Create a combined function that handles the whole pipeline from example to image
|
188 |
+
@spaces.GPU()
|
189 |
+
def process_example_pipeline(example_prompt, seed, randomize_seed, width, height, num_inference_steps, progress=gr.Progress()):
|
190 |
+
# Step 1: Update status
|
191 |
+
progress(0, desc="Starting example processing")
|
192 |
+
progress_status = "Selected example: " + example_prompt
|
193 |
+
|
194 |
+
# Step 2: Refine the prompt
|
195 |
+
progress(0.1, desc="Refining prompt with Mistral")
|
196 |
+
refined, status = refine_prompt(example_prompt, progress)
|
197 |
+
|
198 |
+
if not refined:
|
199 |
+
return example_prompt, "", None, "Failed to refine prompt: " + status
|
200 |
+
|
201 |
+
progress(0.5, desc="Prompt refined, generating image")
|
202 |
+
progress_status = status
|
203 |
+
|
204 |
+
# Step 3: Generate the image
|
205 |
+
image, image_status = infer(refined, seed, randomize_seed, width, height, num_inference_steps, progress)
|
206 |
+
|
207 |
+
progress(1.0, desc="Process complete")
|
208 |
+
final_status = f"{progress_status} → {image_status}"
|
209 |
+
|
210 |
+
return example_prompt, refined, image, final_status
|
211 |
+
|
212 |
def create_interface():
|
213 |
# Preload models if needed
|
214 |
if PRELOAD_MODELS:
|
|
|
276 |
minimum=1,
|
277 |
maximum=50,
|
278 |
step=1,
|
279 |
+
value=6,
|
280 |
)
|
281 |
|
282 |
+
# Examples section - use the pipeline function for examples
|
283 |
gr.Examples(
|
284 |
examples=examples,
|
285 |
+
fn=process_example_pipeline,
|
286 |
+
inputs=[
|
287 |
+
prompt,
|
288 |
+
seed,
|
289 |
+
randomize_seed,
|
290 |
+
width,
|
291 |
+
height,
|
292 |
+
num_inference_steps
|
293 |
+
],
|
294 |
+
outputs=[
|
295 |
+
prompt,
|
296 |
+
refined_prompt,
|
297 |
+
generated_image,
|
298 |
+
error_box
|
299 |
+
],
|
300 |
+
cache_examples=True, # Can be cached now
|
301 |
)
|
302 |
|
303 |
# Event handlers
|