Commit
·
20c6eca
1
Parent(s):
a7c9e18
Refactor progress tracking in generate_image function
Browse files- tabs/images/events.py +4 -2
- tabs/images/handlers.py +6 -3
tabs/images/events.py
CHANGED
@@ -413,9 +413,11 @@ def generate_image(
|
|
413 |
resize_mode,
|
414 |
scheduler, image_height, image_width, image_num_images_per_prompt, # type: ignore
|
415 |
image_num_inference_steps, image_clip_skip, image_guidance_scale, image_seed, # type: ignore
|
416 |
-
refiner, vae
|
|
|
417 |
):
|
418 |
try:
|
|
|
419 |
base_args = {
|
420 |
"model": model,
|
421 |
"prompt": prompt,
|
@@ -507,7 +509,7 @@ def generate_image(
|
|
507 |
base_args = BaseReq(**base_args.__dict__)
|
508 |
|
509 |
return gr.update(
|
510 |
-
value=gen_img(base_args),
|
511 |
interactive=True
|
512 |
)
|
513 |
except Exception as e:
|
|
|
413 |
resize_mode,
|
414 |
scheduler, image_height, image_width, image_num_images_per_prompt, # type: ignore
|
415 |
image_num_inference_steps, image_clip_skip, image_guidance_scale, image_seed, # type: ignore
|
416 |
+
refiner, vae,
|
417 |
+
progress=gr.Progress(track_tqdm=True)
|
418 |
):
|
419 |
try:
|
420 |
+
progress(0, "Configuring arguments...")
|
421 |
base_args = {
|
422 |
"model": model,
|
423 |
"prompt": prompt,
|
|
|
509 |
base_args = BaseReq(**base_args.__dict__)
|
510 |
|
511 |
return gr.update(
|
512 |
+
value=gen_img(base_args, progress),
|
513 |
interactive=True
|
514 |
)
|
515 |
except Exception as e:
|
tabs/images/handlers.py
CHANGED
@@ -205,10 +205,12 @@ def cleanup(pipeline, loras = None, embeddings = None):
|
|
205 |
|
206 |
|
207 |
# Gen Function
|
208 |
-
def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
|
|
|
209 |
pipeline_args = get_pipe(request)
|
210 |
pipeline = pipeline_args["pipeline"]
|
211 |
try:
|
|
|
212 |
positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_prompt_attention(pipeline, request.prompt, request.negative_prompt)
|
213 |
|
214 |
# Common Args
|
@@ -243,15 +245,16 @@ def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
|
|
243 |
args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)[0]
|
244 |
|
245 |
# Generate
|
|
|
246 |
images = pipeline(**args).images
|
247 |
|
248 |
# Refiner
|
249 |
if request.refiner:
|
250 |
images = refiner(image=images, prompt=request.prompt, num_inference_steps=40, denoising_start=0.7).images
|
251 |
|
|
|
|
|
252 |
return images
|
253 |
except Exception as e:
|
254 |
cleanup(pipeline, request.loras, request.embeddings)
|
255 |
raise gr.Error(f"Error: {e}")
|
256 |
-
finally:
|
257 |
-
cleanup(pipeline, request.loras, request.embeddings)
|
|
|
205 |
|
206 |
|
207 |
# Gen Function
|
208 |
+
def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq, progress=gr.Progress(track_tqdm=True)):
|
209 |
+
progress(0.1, "Loading Pipeline")
|
210 |
pipeline_args = get_pipe(request)
|
211 |
pipeline = pipeline_args["pipeline"]
|
212 |
try:
|
213 |
+
progress(0.5, "Configuring Pipeline")
|
214 |
positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_prompt_attention(pipeline, request.prompt, request.negative_prompt)
|
215 |
|
216 |
# Common Args
|
|
|
245 |
args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)[0]
|
246 |
|
247 |
# Generate
|
248 |
+
progress(0.9, "Generating Images")
|
249 |
images = pipeline(**args).images
|
250 |
|
251 |
# Refiner
|
252 |
if request.refiner:
|
253 |
images = refiner(image=images, prompt=request.prompt, num_inference_steps=40, denoising_start=0.7).images
|
254 |
|
255 |
+
progress(1.0, "Cleaning Up")
|
256 |
+
cleanup(pipeline, request.loras, request.embeddings)
|
257 |
return images
|
258 |
except Exception as e:
|
259 |
cleanup(pipeline, request.loras, request.embeddings)
|
260 |
raise gr.Error(f"Error: {e}")
|
|
|
|