barreloflube commited on
Commit
20c6eca
·
1 Parent(s): a7c9e18

Refactor progress tracking in generate_image function

Browse files
Files changed (2) hide show
  1. tabs/images/events.py +4 -2
  2. tabs/images/handlers.py +6 -3
tabs/images/events.py CHANGED
@@ -413,9 +413,11 @@ def generate_image(
413
  resize_mode,
414
  scheduler, image_height, image_width, image_num_images_per_prompt, # type: ignore
415
  image_num_inference_steps, image_clip_skip, image_guidance_scale, image_seed, # type: ignore
416
- refiner, vae
 
417
  ):
418
  try:
 
419
  base_args = {
420
  "model": model,
421
  "prompt": prompt,
@@ -507,7 +509,7 @@ def generate_image(
507
  base_args = BaseReq(**base_args.__dict__)
508
 
509
  return gr.update(
510
- value=gen_img(base_args),
511
  interactive=True
512
  )
513
  except Exception as e:
 
413
  resize_mode,
414
  scheduler, image_height, image_width, image_num_images_per_prompt, # type: ignore
415
  image_num_inference_steps, image_clip_skip, image_guidance_scale, image_seed, # type: ignore
416
+ refiner, vae,
417
+ progress=gr.Progress(track_tqdm=True)
418
  ):
419
  try:
420
+ progress(0, "Configuring arguments...")
421
  base_args = {
422
  "model": model,
423
  "prompt": prompt,
 
509
  base_args = BaseReq(**base_args.__dict__)
510
 
511
  return gr.update(
512
+ value=gen_img(base_args, progress),
513
  interactive=True
514
  )
515
  except Exception as e:
tabs/images/handlers.py CHANGED
@@ -205,10 +205,12 @@ def cleanup(pipeline, loras = None, embeddings = None):
205
 
206
 
207
  # Gen Function
208
- def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
 
209
  pipeline_args = get_pipe(request)
210
  pipeline = pipeline_args["pipeline"]
211
  try:
 
212
  positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_prompt_attention(pipeline, request.prompt, request.negative_prompt)
213
 
214
  # Common Args
@@ -243,15 +245,16 @@ def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
243
  args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)[0]
244
 
245
  # Generate
 
246
  images = pipeline(**args).images
247
 
248
  # Refiner
249
  if request.refiner:
250
  images = refiner(image=images, prompt=request.prompt, num_inference_steps=40, denoising_start=0.7).images
251
 
 
 
252
  return images
253
  except Exception as e:
254
  cleanup(pipeline, request.loras, request.embeddings)
255
  raise gr.Error(f"Error: {e}")
256
- finally:
257
- cleanup(pipeline, request.loras, request.embeddings)
 
205
 
206
 
207
  # Gen Function
208
+ def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq, progress=gr.Progress(track_tqdm=True)):
209
+ progress(0.1, "Loading Pipeline")
210
  pipeline_args = get_pipe(request)
211
  pipeline = pipeline_args["pipeline"]
212
  try:
213
+ progress(0.5, "Configuring Pipeline")
214
  positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_prompt_attention(pipeline, request.prompt, request.negative_prompt)
215
 
216
  # Common Args
 
245
  args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)[0]
246
 
247
  # Generate
248
+ progress(0.9, "Generating Images")
249
  images = pipeline(**args).images
250
 
251
  # Refiner
252
  if request.refiner:
253
  images = refiner(image=images, prompt=request.prompt, num_inference_steps=40, denoising_start=0.7).images
254
 
255
+ progress(1.0, "Cleaning Up")
256
+ cleanup(pipeline, request.loras, request.embeddings)
257
  return images
258
  except Exception as e:
259
  cleanup(pipeline, request.loras, request.embeddings)
260
  raise gr.Error(f"Error: {e}")