Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Update pipeline.py
Browse files- pipeline.py +218 -67
    	
        pipeline.py
    CHANGED
    
    | @@ -223,27 +223,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile | |
| 223 | 
             
                    clip_skip: Optional[int] = None,
         | 
| 224 | 
             
                    max_sequence_length: int = 512,
         | 
| 225 | 
             
                    lora_scale: Optional[float] = None,
         | 
| 226 | 
            -
                ):
         | 
| 227 | 
            -
                    r"""
         | 
| 228 | 
            -
             | 
| 229 | 
            -
                    Args:
         | 
| 230 | 
            -
                        prompt (`str` or `List[str]`, *optional*):     
         | 
| 231 | 
            -
                        prompt_2 (`str` or `List[str]`, *optional*):
         | 
| 232 | 
            -
                            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
         | 
| 233 | 
            -
                            used in all text-encoders
         | 
| 234 | 
            -
                        device: (`torch.device`):
         | 
| 235 | 
            -
                            torch device
         | 
| 236 | 
            -
                        num_images_per_prompt (`int`):
         | 
| 237 | 
            -
                            number of images that should be generated per prompt
         | 
| 238 | 
            -
                        prompt_embeds (`torch.FloatTensor`, *optional*):
         | 
| 239 | 
            -
                            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
         | 
| 240 | 
            -
                            provided, text embeddings will be generated from `prompt` input argument.
         | 
| 241 | 
            -
                        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
         | 
| 242 | 
            -
                            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
         | 
| 243 | 
            -
                            If not provided, pooled text embeddings will be generated from `prompt` input argument.
         | 
| 244 | 
            -
                        lora_scale (`float`, *optional*):
         | 
| 245 | 
            -
                            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
         | 
| 246 | 
            -
                    """
         | 
| 247 | 
             
                    device = device or self._execution_device
         | 
| 248 |  | 
| 249 | 
             
                    if device is None:
         | 
| @@ -297,7 +277,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile | |
| 297 | 
             
                            batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
         | 
| 298 | 
             
                        )
         | 
| 299 |  | 
| 300 | 
            -
             | 
| 301 | 
             
                            raise TypeError(
         | 
| 302 | 
             
                                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
         | 
| 303 | 
             
                                f" {type(prompt)}."
         | 
| @@ -309,29 +289,29 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile | |
| 309 | 
             
                                " the batch size of `prompt`."
         | 
| 310 | 
             
                            )
         | 
| 311 |  | 
| 312 | 
            -
             | 
| 313 | 
            -
             | 
| 314 | 
            -
             | 
| 315 | 
            -
             | 
| 316 | 
            -
             | 
| 317 | 
             
                        )
         | 
| 318 | 
            -
             | 
| 319 |  | 
| 320 | 
            -
             | 
| 321 | 
            -
             | 
| 322 | 
            -
             | 
| 323 | 
            -
             | 
| 324 | 
            -
             | 
| 325 | 
             
                        )
         | 
| 326 |  | 
| 327 | 
            -
             | 
| 328 | 
             
                            negative_clip_prompt_embeds,
         | 
| 329 | 
             
                            (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
         | 
| 330 | 
             
                        )
         | 
| 331 |  | 
| 332 | 
            -
             | 
| 333 | 
            -
             | 
| 334 | 
            -
             | 
| 335 | 
             
                        )
         | 
| 336 |  | 
| 337 | 
             
                    if self.text_encoder is not None:
         | 
| @@ -343,26 +323,8 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile | |
| 343 | 
             
                    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
         | 
| 344 |  | 
| 345 | 
             
                    return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
         | 
| 346 | 
            -
             | 
| 347 | 
            -
             | 
| 348 | 
            -
                def prepare_extra_step_kwargs(self, generator, eta):
         | 
| 349 | 
            -
                    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         | 
| 350 | 
            -
                    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         | 
| 351 | 
            -
                    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         | 
| 352 | 
            -
                    # and should be between [0, 1]
         | 
| 353 | 
            -
             | 
| 354 | 
            -
                    accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
         | 
| 355 | 
            -
                    extra_step_kwargs = {}
         | 
| 356 | 
            -
                    if accepts_eta:
         | 
| 357 | 
            -
                        extra_step_kwargs["eta"] = eta
         | 
| 358 | 
            -
             | 
| 359 | 
            -
                    # check if the scheduler accepts generator
         | 
| 360 | 
            -
                    accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
         | 
| 361 | 
            -
                    if accepts_generator:
         | 
| 362 | 
            -
                        extra_step_kwargs["generator"] = generator
         | 
| 363 | 
            -
                    return extra_step_kwargs
         | 
| 364 | 
            -
             | 
| 365 | 
            -
                def check_inputs(
         | 
| 366 | 
             
                    self,
         | 
| 367 | 
             
                    prompt,
         | 
| 368 | 
             
                    prompt_2,
         | 
| @@ -464,6 +426,23 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile | |
| 464 | 
             
                    latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
         | 
| 465 |  | 
| 466 | 
             
                    return latents
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 467 |  | 
| 468 | 
             
                def enable_vae_slicing(self):
         | 
| 469 | 
             
                    r"""
         | 
| @@ -546,7 +525,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile | |
| 546 | 
             
                @property
         | 
| 547 | 
             
                def interrupt(self):
         | 
| 548 | 
             
                    return self._interrupt
         | 
| 549 | 
            -
             | 
| 550 | 
             
                @torch.no_grad()
         | 
| 551 | 
             
                @torch.inference_mode()
         | 
| 552 | 
             
                def generate_image(
         | 
| @@ -652,6 +631,178 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile | |
| 652 | 
             
                    # Handle guidance
         | 
| 653 | 
             
                    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
         | 
| 654 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 655 | 
             
                    # 6. Denoising loop
         | 
| 656 | 
             
                    with self.progress_bar(total=num_inference_steps) as progress_bar:
         | 
| 657 | 
             
                        for i, t in enumerate(timesteps):
         | 
| @@ -694,18 +845,18 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile | |
| 694 | 
             
                         # Yield intermediate result
         | 
| 695 | 
             
                        torch.cuda.empty_cache()
         | 
| 696 |  | 
| 697 | 
            -
             | 
| 698 | 
            -
             | 
| 699 | 
             
                                # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
         | 
| 700 | 
            -
             | 
| 701 |  | 
| 702 | 
            -
             | 
| 703 | 
            -
             | 
| 704 | 
             
                            for k in callback_on_step_end_tensor_inputs:
         | 
| 705 | 
             
                                callback_kwargs[k] = locals()[k]
         | 
| 706 | 
             
                            callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
         | 
| 707 |  | 
| 708 | 
            -
             | 
| 709 | 
             
                            prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
         | 
| 710 | 
             
                            negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
         | 
| 711 | 
             
                            negative_pooled_prompt_embeds = callback_outputs.pop(
         | 
| @@ -713,10 +864,10 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile | |
| 713 | 
             
                            )
         | 
| 714 |  | 
| 715 | 
             
                        # call the callback, if provided
         | 
| 716 | 
            -
             | 
| 717 | 
             
                            progress_bar.update()
         | 
| 718 | 
            -
             | 
| 719 | 
            -
                     | 
| 720 | 
             
                    return self._decode_latents_to_image(latents, height, width, output_type)
         | 
| 721 | 
             
                    self.maybe_free_model_hooks()
         | 
| 722 | 
             
                    torch.cuda.empty_cache()
         | 
|  | |
| 223 | 
             
                    clip_skip: Optional[int] = None,
         | 
| 224 | 
             
                    max_sequence_length: int = 512,
         | 
| 225 | 
             
                    lora_scale: Optional[float] = None,
         | 
| 226 | 
            +
                ): 
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 227 | 
             
                    device = device or self._execution_device
         | 
| 228 |  | 
| 229 | 
             
                    if device is None:
         | 
|  | |
| 277 | 
             
                            batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
         | 
| 278 | 
             
                        )
         | 
| 279 |  | 
| 280 | 
            +
                    if prompt is not None and type(prompt) is not type(negative_prompt):
         | 
| 281 | 
             
                            raise TypeError(
         | 
| 282 | 
             
                                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
         | 
| 283 | 
             
                                f" {type(prompt)}."
         | 
|  | |
| 289 | 
             
                                " the batch size of `prompt`."
         | 
| 290 | 
             
                            )
         | 
| 291 |  | 
| 292 | 
            +
                    negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
         | 
| 293 | 
            +
                        negative_prompt,
         | 
| 294 | 
            +
                        device=device,
         | 
| 295 | 
            +
                        num_images_per_prompt=num_images_per_prompt,
         | 
| 296 | 
            +
                        clip_skip=None,
         | 
| 297 | 
             
                        )
         | 
| 298 | 
            +
                    negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
         | 
| 299 |  | 
| 300 | 
            +
                    t5_negative_prompt_embed = self._get_t5_prompt_embeds(
         | 
| 301 | 
            +
                        prompt=negative_prompt_2,
         | 
| 302 | 
            +
                        num_images_per_prompt=num_images_per_prompt,
         | 
| 303 | 
            +
                        max_sequence_length=max_sequence_length,
         | 
| 304 | 
            +
                        device=device,
         | 
| 305 | 
             
                        )
         | 
| 306 |  | 
| 307 | 
            +
                    negative_clip_prompt_embeds = torch.nn.functional.pad(
         | 
| 308 | 
             
                            negative_clip_prompt_embeds,
         | 
| 309 | 
             
                            (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
         | 
| 310 | 
             
                        )
         | 
| 311 |  | 
| 312 | 
            +
                    negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
         | 
| 313 | 
            +
                    negative_pooled_prompt_embeds = torch.cat(
         | 
| 314 | 
            +
                        [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
         | 
| 315 | 
             
                        )
         | 
| 316 |  | 
| 317 | 
             
                    if self.text_encoder is not None:
         | 
|  | |
| 323 | 
             
                    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
         | 
| 324 |  | 
| 325 | 
             
                    return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
         | 
| 326 | 
            +
                    
         | 
| 327 | 
            +
                    def check_inputs(
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 328 | 
             
                    self,
         | 
| 329 | 
             
                    prompt,
         | 
| 330 | 
             
                    prompt_2,
         | 
|  | |
| 426 | 
             
                    latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
         | 
| 427 |  | 
| 428 | 
             
                    return latents
         | 
| 429 | 
            +
                # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
         | 
| 430 | 
            +
                def prepare_extra_step_kwargs(self, generator, eta):
         | 
| 431 | 
            +
                    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         | 
| 432 | 
            +
                    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         | 
| 433 | 
            +
                    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         | 
| 434 | 
            +
                    # and should be between [0, 1]
         | 
| 435 | 
            +
             | 
| 436 | 
            +
                    accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
         | 
| 437 | 
            +
                    extra_step_kwargs = {}
         | 
| 438 | 
            +
                    if accepts_eta:
         | 
| 439 | 
            +
                        extra_step_kwargs["eta"] = eta
         | 
| 440 | 
            +
             | 
| 441 | 
            +
                    # check if the scheduler accepts generator
         | 
| 442 | 
            +
                    accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
         | 
| 443 | 
            +
                    if accepts_generator:
         | 
| 444 | 
            +
                        extra_step_kwargs["generator"] = generator
         | 
| 445 | 
            +
                    return extra_step_kwargs
         | 
| 446 |  | 
| 447 | 
             
                def enable_vae_slicing(self):
         | 
| 448 | 
             
                    r"""
         | 
|  | |
| 525 | 
             
                @property
         | 
| 526 | 
             
                def interrupt(self):
         | 
| 527 | 
             
                    return self._interrupt
         | 
| 528 | 
            +
                    
         | 
| 529 | 
             
                @torch.no_grad()
         | 
| 530 | 
             
                @torch.inference_mode()
         | 
| 531 | 
             
                def generate_image(
         | 
|  | |
| 631 | 
             
                    # Handle guidance
         | 
| 632 | 
             
                    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
         | 
| 633 |  | 
| 634 | 
            +
                    # 6. Denoising loop
         | 
| 635 | 
            +
                    with self.progress_bar(total=num_inference_steps) as progress_bar:
         | 
| 636 | 
            +
                        for i, t in enumerate(timesteps):
         | 
| 637 | 
            +
                            if self.interrupt:
         | 
| 638 | 
            +
                                continue
         | 
| 639 | 
            +
             | 
| 640 | 
            +
                        latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
         | 
| 641 | 
            +
             | 
| 642 | 
            +
                        timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
         | 
| 643 | 
            +
             | 
| 644 | 
            +
                        noise_pred = self.transformer(
         | 
| 645 | 
            +
                            hidden_states=latent_model_input,
         | 
| 646 | 
            +
                            timestep=timestep / 1000,
         | 
| 647 | 
            +
                            guidance=guidance,
         | 
| 648 | 
            +
                            pooled_projections=pooled_prompt_embeds,
         | 
| 649 | 
            +
                            encoder_hidden_states=prompt_embeds,
         | 
| 650 | 
            +
                            txt_ids=text_ids,
         | 
| 651 | 
            +
                            img_ids=latent_image_ids,
         | 
| 652 | 
            +
                            joint_attention_kwargs=self.joint_attention_kwargs,
         | 
| 653 | 
            +
                            return_dict=False,
         | 
| 654 | 
            +
                          )[0]
         | 
| 655 | 
            +
                        
         | 
| 656 | 
            +
                        noise_pred_uncond = self.transformer(
         | 
| 657 | 
            +
                            hidden_states=latents,
         | 
| 658 | 
            +
                            timestep=timestep / 1000,
         | 
| 659 | 
            +
                            guidance=guidance,
         | 
| 660 | 
            +
                            pooled_projections=negative_pooled_prompt_embeds,
         | 
| 661 | 
            +
                            encoder_hidden_states=negative_prompt_embeds,
         | 
| 662 | 
            +
                            img_ids=latent_image_ids,
         | 
| 663 | 
            +
                            joint_attention_kwargs=self.joint_attention_kwargs,
         | 
| 664 | 
            +
                            return_dict=False,
         | 
| 665 | 
            +
                          )[0]
         | 
| 666 | 
            +
             | 
| 667 | 
            +
                        if self.do_classifier_free_guidance:
         | 
| 668 | 
            +
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
         | 
| 669 | 
            +
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
         | 
| 670 | 
            +
             | 
| 671 | 
            +
                        latents_dtype = latents.dtype
         | 
| 672 | 
            +
                        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
         | 
| 673 | 
            +
                         # Yield intermediate result
         | 
| 674 | 
            +
                        torch.cuda.empty_cache()
         | 
| 675 | 
            +
             | 
| 676 | 
            +
                    if latents.dtype != latents_dtype:
         | 
| 677 | 
            +
                         if torch.backends.mps.is_available():
         | 
| 678 | 
            +
                                # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
         | 
| 679 | 
            +
                            latents = latents.to(latents_dtype)
         | 
| 680 | 
            +
             | 
| 681 | 
            +
                    if callback_on_step_end is not None:
         | 
| 682 | 
            +
                        callback_kwargs = {}
         | 
| 683 | 
            +
                         for k in callback_on_step_end_tensor_inputs:
         | 
| 684 | 
            +
                            callback_kwargs[k] = locals()[k]
         | 
| 685 | 
            +
                            callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
         | 
| 686 | 
            +
             | 
| 687 | 
            +
                        latents = callback_outputs.pop("latents", latents)
         | 
| 688 | 
            +
                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
         | 
| 689 | 
            +
                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
         | 
| 690 | 
            +
                        negative_pooled_prompt_embeds = callback_outputs.pop(
         | 
| 691 | 
            +
                                "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
         | 
| 692 | 
            +
                            )
         | 
| 693 | 
            +
             | 
| 694 | 
            +
                        # call the callback, if provided
         | 
| 695 | 
            +
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
         | 
| 696 | 
            +
                        progress_bar.update()
         | 
| 697 | 
            +
             | 
| 698 | 
            +
                    # Final image
         | 
| 699 | 
            +
                    return self._decode_latents_to_image(latents, height, width, output_type)
         | 
| 700 | 
            +
                    self.maybe_free_model_hooks()
         | 
| 701 | 
            +
                    torch.cuda.empty_cache()
         | 
| 702 | 
            +
                
         | 
| 703 | 
            +
                def __call__(
         | 
| 704 | 
            +
                    self,
         | 
| 705 | 
            +
                    prompt: Union[str, List[str]] = None,
         | 
| 706 | 
            +
                    prompt_2: Optional[Union[str, List[str]]] = None,
         | 
| 707 | 
            +
                    height: Optional[int] = None,
         | 
| 708 | 
            +
                    width: Optional[int] = None,
         | 
| 709 | 
            +
                    negative_prompt: Optional[Union[str, List[str]]] = None,
         | 
| 710 | 
            +
                    negative_prompt_2: Optional[Union[str, List[str]]] = None,
         | 
| 711 | 
            +
                    num_inference_steps: int = 8,
         | 
| 712 | 
            +
                    timesteps: List[int] = None,
         | 
| 713 | 
            +
                    eta: float = 0.0,
         | 
| 714 | 
            +
                    guidance_scale: float = 3.5,
         | 
| 715 | 
            +
                    num_images_per_prompt: Optional[int] = 1,
         | 
| 716 | 
            +
                    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         | 
| 717 | 
            +
                    latents: Optional[torch.FloatTensor] = None,
         | 
| 718 | 
            +
                    prompt_embeds: Optional[torch.FloatTensor] = None,
         | 
| 719 | 
            +
                    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         | 
| 720 | 
            +
                    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         | 
| 721 | 
            +
                    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         | 
| 722 | 
            +
                    output_type: Optional[str] = "pil",
         | 
| 723 | 
            +
                    return_dict: bool = True,
         | 
| 724 | 
            +
                    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         | 
| 725 | 
            +
                    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         | 
| 726 | 
            +
                    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         | 
| 727 | 
            +
                    clip_skip: Optional[int] = None,
         | 
| 728 | 
            +
                    max_sequence_length: int = 300,
         | 
| 729 | 
            +
                ):
         | 
| 730 | 
            +
                    height = height or self.default_sample_size * self.vae_scale_factor
         | 
| 731 | 
            +
                    width = width or self.default_sample_size * self.vae_scale_factor
         | 
| 732 | 
            +
                    
         | 
| 733 | 
            +
                    # 1. Check inputs
         | 
| 734 | 
            +
                    self.check_inputs(
         | 
| 735 | 
            +
                        prompt,
         | 
| 736 | 
            +
                        prompt_2,
         | 
| 737 | 
            +
                        height,
         | 
| 738 | 
            +
                        width,
         | 
| 739 | 
            +
                        negative_prompt=negative_prompt,
         | 
| 740 | 
            +
                        negative_prompt_2=negative_prompt_2,
         | 
| 741 | 
            +
                        prompt_embeds=prompt_embeds,
         | 
| 742 | 
            +
                        negative_prompt_embeds=negative_prompt_embeds,
         | 
| 743 | 
            +
                        pooled_prompt_embeds=pooled_prompt_embeds,
         | 
| 744 | 
            +
                        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
         | 
| 745 | 
            +
                        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
         | 
| 746 | 
            +
                        max_sequence_length=max_sequence_length,
         | 
| 747 | 
            +
                        lora_scale=lora_scale
         | 
| 748 | 
            +
                    )
         | 
| 749 | 
            +
             | 
| 750 | 
            +
                    self._guidance_scale = guidance_scale
         | 
| 751 | 
            +
                    self._clip_skip = clip_skip
         | 
| 752 | 
            +
                    self._joint_attention_kwargs = joint_attention_kwargs
         | 
| 753 | 
            +
                    self._interrupt = False
         | 
| 754 | 
            +
             | 
| 755 | 
            +
                    # 2. Define call parameters
         | 
| 756 | 
            +
                    if prompt is not None and isinstance(prompt, str):
         | 
| 757 | 
            +
                        batch_size = 1
         | 
| 758 | 
            +
                    elif prompt is not None and isinstance(prompt, list):
         | 
| 759 | 
            +
                        batch_size = len(prompt)
         | 
| 760 | 
            +
                    else:
         | 
| 761 | 
            +
                        batch_size = prompt_embeds.shape[0]
         | 
| 762 | 
            +
             | 
| 763 | 
            +
                    device = self._execution_device
         | 
| 764 | 
            +
             | 
| 765 | 
            +
                    do_classifier_free_guidance = guidance_scale > 1.0
         | 
| 766 | 
            +
             | 
| 767 | 
            +
                    lora_scale = (
         | 
| 768 | 
            +
                        self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         | 
| 769 | 
            +
                    )
         | 
| 770 | 
            +
             | 
| 771 | 
            +
                    if self.do_classifier_free_guidance:
         | 
| 772 | 
            +
                        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
         | 
| 773 | 
            +
                        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
         | 
| 774 | 
            +
             | 
| 775 | 
            +
                    # 4. Prepare latent variables
         | 
| 776 | 
            +
                    num_channels_latents = self.transformer.config.in_channels // 4
         | 
| 777 | 
            +
                    latents, latent_image_ids = self.prepare_latents(
         | 
| 778 | 
            +
                        batch_size * num_images_per_prompt,
         | 
| 779 | 
            +
                        num_channels_latents,
         | 
| 780 | 
            +
                        height,
         | 
| 781 | 
            +
                        width,
         | 
| 782 | 
            +
                        prompt_embeds.dtype,
         | 
| 783 | 
            +
                        negative_prompt_embeds.dtype,
         | 
| 784 | 
            +
                        device,
         | 
| 785 | 
            +
                        generator,
         | 
| 786 | 
            +
                        latents,
         | 
| 787 | 
            +
                    )
         | 
| 788 | 
            +
                    
         | 
| 789 | 
            +
                    # 5. Prepare timesteps
         | 
| 790 | 
            +
                    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
         | 
| 791 | 
            +
                    image_seq_len = latents.shape[1]
         | 
| 792 | 
            +
                    mu = calculate_timestep_shift(image_seq_len)
         | 
| 793 | 
            +
                    timesteps, num_inference_steps = prepare_timesteps(
         | 
| 794 | 
            +
                        self.scheduler,
         | 
| 795 | 
            +
                        num_inference_steps,
         | 
| 796 | 
            +
                        device,
         | 
| 797 | 
            +
                        timesteps,
         | 
| 798 | 
            +
                        sigmas,
         | 
| 799 | 
            +
                        mu=mu,
         | 
| 800 | 
            +
                    )
         | 
| 801 | 
            +
                    self._num_timesteps = len(timesteps)
         | 
| 802 | 
            +
             | 
| 803 | 
            +
                    # Handle guidance
         | 
| 804 | 
            +
                    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
         | 
| 805 | 
            +
             | 
| 806 | 
             
                    # 6. Denoising loop
         | 
| 807 | 
             
                    with self.progress_bar(total=num_inference_steps) as progress_bar:
         | 
| 808 | 
             
                        for i, t in enumerate(timesteps):
         | 
|  | |
| 845 | 
             
                         # Yield intermediate result
         | 
| 846 | 
             
                        torch.cuda.empty_cache()
         | 
| 847 |  | 
| 848 | 
            +
                    if latents.dtype != latents_dtype:
         | 
| 849 | 
            +
                        if torch.backends.mps.is_available():
         | 
| 850 | 
             
                                # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
         | 
| 851 | 
            +
                            latents = latents.to(latents_dtype)
         | 
| 852 |  | 
| 853 | 
            +
                    if callback_on_step_end is not None:
         | 
| 854 | 
            +
                        callback_kwargs = {}
         | 
| 855 | 
             
                            for k in callback_on_step_end_tensor_inputs:
         | 
| 856 | 
             
                                callback_kwargs[k] = locals()[k]
         | 
| 857 | 
             
                            callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
         | 
| 858 |  | 
| 859 | 
            +
                        latents = callback_outputs.pop("latents", latents)
         | 
| 860 | 
             
                            prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
         | 
| 861 | 
             
                            negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
         | 
| 862 | 
             
                            negative_pooled_prompt_embeds = callback_outputs.pop(
         | 
|  | |
| 864 | 
             
                            )
         | 
| 865 |  | 
| 866 | 
             
                        # call the callback, if provided
         | 
| 867 | 
            +
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
         | 
| 868 | 
             
                            progress_bar.update()
         | 
| 869 | 
            +
                        # Final image
         | 
| 870 | 
            +
                    
         | 
| 871 | 
             
                    return self._decode_latents_to_image(latents, height, width, output_type)
         | 
| 872 | 
             
                    self.maybe_free_model_hooks()
         | 
| 873 | 
             
                    torch.cuda.empty_cache()
         | 
