barreloflube commited on
Commit
bf49323
·
1 Parent(s): 4ceee35

Refactor image_tab function to update generator initialization and remove device parameter

Browse files
Files changed (1) hide show
  1. tabs/images/handlers.py +3 -2
tabs/images/handlers.py CHANGED
@@ -63,6 +63,7 @@ def get_pipe(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
63
  pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)
64
  elif request.custom_addons:
65
  ...
 
66
 
67
  # Enable or Disable Vae
68
  if request.vae:
@@ -213,7 +214,7 @@ def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq, progress=gr.Prog
213
  progress(0.3, "Getting Prompt Embeddings")
214
  # Get Prompt Embeddings
215
  if isinstance(pipeline, flux_pipes):
216
- positive_prompt_embeds, positive_prompt_pooled = get_weighted_text_embeddings_flux1(pipeline, request.prompt, device=device)
217
  elif isinstance(pipeline, sd_pipes):
218
  positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_weighted_text_embeddings_sdxl(pipeline, request.prompt, request.negative_prompt)
219
 
@@ -227,7 +228,7 @@ def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq, progress=gr.Prog
227
  'num_images_per_prompt': request.num_images_per_prompt,
228
  'num_inference_steps': request.num_inference_steps,
229
  'guidance_scale': request.guidance_scale,
230
- 'generator': [torch.Generator(device=device).manual_seed(request.seed + i) if not request.seed is any([None, 0, -1]) else torch.Generator(device=device).manual_seed(random.randint(0, 2**32 - 1)) for i in range(request.num_images_per_prompt)],
231
  }
232
 
233
  if isinstance(pipeline, sd_pipes):
 
63
  pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)
64
  elif request.custom_addons:
65
  ...
66
+ gr.Info(f"Pipeline Mode: {type(pipe_args['pipeline'])}")
67
 
68
  # Enable or Disable Vae
69
  if request.vae:
 
214
  progress(0.3, "Getting Prompt Embeddings")
215
  # Get Prompt Embeddings
216
  if isinstance(pipeline, flux_pipes):
217
+ positive_prompt_embeds, positive_prompt_pooled = get_weighted_text_embeddings_flux1(pipeline, request.prompt)
218
  elif isinstance(pipeline, sd_pipes):
219
  positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_weighted_text_embeddings_sdxl(pipeline, request.prompt, request.negative_prompt)
220
 
 
228
  'num_images_per_prompt': request.num_images_per_prompt,
229
  'num_inference_steps': request.num_inference_steps,
230
  'guidance_scale': request.guidance_scale,
231
+ 'generator': [torch.Generator().manual_seed(request.seed + i) if not request.seed is any([None, 0, -1]) else torch.Generator().manual_seed(random.randint(0, 2**32 - 1)) for i in range(request.num_images_per_prompt)],
232
  }
233
 
234
  if isinstance(pipeline, sd_pipes):