lionelgarnier committed on
Commit 412c4ad · 1 Parent(s): 8538434

add example selection and processing pipeline functions

Files changed (1): app.py +50 -8
app.py CHANGED
@@ -22,7 +22,7 @@ _image_gen_pipeline = None
 @spaces.GPU()
 def get_image_gen_pipeline():
     global _image_gen_pipeline
-    if _image_gen_pipeline is None:
+    if (_image_gen_pipeline is None):
         try:
             device = "cuda" if torch.cuda.is_available() else "cpu"
             dtype = torch.bfloat16
@@ -42,7 +42,7 @@ def get_image_gen_pipeline():
 @spaces.GPU()
 def get_text_gen_pipeline():
     global _text_gen_pipeline
-    if _text_gen_pipeline is None:
+    if (_text_gen_pipeline is None):
         try:
             device = "cuda" if torch.cuda.is_available() else "cpu"
             tokenizer = AutoTokenizer.from_pretrained(
@@ -179,6 +179,36 @@ def preload_models():
     print(status)
     return success
 
+# Add a new function to handle example selection
+def handle_example_click(example_prompt):
+    # Immediately return the example prompt to update the UI
+    return example_prompt, "Example selected - click 'Refine prompt with Mistral' to process"
+
+# Create a combined function that handles the whole pipeline from example to image
+@spaces.GPU()
+def process_example_pipeline(example_prompt, seed, randomize_seed, width, height, num_inference_steps, progress=gr.Progress()):
+    # Step 1: Update status
+    progress(0, desc="Starting example processing")
+    progress_status = "Selected example: " + example_prompt
+
+    # Step 2: Refine the prompt
+    progress(0.1, desc="Refining prompt with Mistral")
+    refined, status = refine_prompt(example_prompt, progress)
+
+    if not refined:
+        return example_prompt, "", None, "Failed to refine prompt: " + status
+
+    progress(0.5, desc="Prompt refined, generating image")
+    progress_status = status
+
+    # Step 3: Generate the image
+    image, image_status = infer(refined, seed, randomize_seed, width, height, num_inference_steps, progress)
+
+    progress(1.0, desc="Process complete")
+    final_status = f"{progress_status} → {image_status}"
+
+    return example_prompt, refined, image, final_status
+
 def create_interface():
     # Preload models if needed
     if PRELOAD_MODELS:
@@ -246,16 +276,28 @@ def create_interface():
                 minimum=1,
                 maximum=50,
                 step=1,
-                value=10,
+                value=6,
             )
 
-            # Examples section
+            # Examples section - use the pipeline function for examples
             gr.Examples(
                 examples=examples,
-                fn=refine_prompt,
-                inputs=[prompt],
-                outputs=[refined_prompt],
-                cache_examples=True,
+                fn=process_example_pipeline,
+                inputs=[
+                    prompt,
+                    seed,
+                    randomize_seed,
+                    width,
+                    height,
+                    num_inference_steps
+                ],
+                outputs=[
+                    prompt,
+                    refined_prompt,
+                    generated_image,
+                    error_box
+                ],
+                cache_examples=True,  # Can be cached now
             )
 
             # Event handlers
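
A note on the `cache_examples=True` line at the end of the last hunk: with `gr.Examples`, enabling caching makes Gradio run `fn` on each example row ahead of time (at startup) and serve the stored outputs when a row is clicked, which is why wiring the examples to the full `process_example_pipeline` makes them cacheable. Below is a minimal sketch of that mechanism; the function and component names are hypothetical stand-ins and are not taken from this app.

# Minimal sketch of gr.Examples caching (stand-in names, not from app.py):
# with cache_examples=True, Gradio calls `fn` once per example row up front
# and shows the cached outputs when a row is clicked.
import gradio as gr

def echo_upper(text):
    # hypothetical stand-in for process_example_pipeline; any fn whose
    # inputs/outputs match the listed components behaves the same way
    return text.upper()

with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Prompt")
    result_box = gr.Textbox(label="Result")
    gr.Examples(
        examples=[["a red fox in the snow"], ["a quiet harbor at dawn"]],
        fn=echo_upper,
        inputs=[prompt_box],
        outputs=[result_box],
        cache_examples=True,  # outputs are precomputed and served from cache
    )

if __name__ == "__main__":
    demo.launch()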