Spaces:
Runtime error
Runtime error
Update pipeline.py
Browse files- pipeline.py +12 -27
pipeline.py
CHANGED
|
@@ -625,35 +625,20 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
| 625 |
if torch.backends.mps.is_available():
|
| 626 |
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 627 |
latents = latents.to(latents_dtype)
|
| 628 |
-
|
| 629 |
-
if callback_on_step_end is not None:
|
| 630 |
-
callback_kwargs = {}
|
| 631 |
-
for k in callback_on_step_end_tensor_inputs:
|
| 632 |
-
callback_kwargs[k] = locals()[k]
|
| 633 |
-
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 634 |
-
|
| 635 |
-
latents = callback_outputs.pop("latents", latents)
|
| 636 |
-
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 637 |
|
|
|
|
| 638 |
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 639 |
progress_bar.update()
|
| 640 |
-
|
| 641 |
-
if XLA_AVAILABLE:
|
| 642 |
-
xm.mark_step()
|
| 643 |
-
|
| 644 |
-
if output_type == "latent":
|
| 645 |
-
image = latents
|
| 646 |
-
|
| 647 |
-
else:
|
| 648 |
-
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
| 649 |
-
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
| 650 |
-
image = self.vae.decode(latents, return_dict=False)[0]
|
| 651 |
-
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 652 |
|
| 653 |
-
#
|
|
|
|
| 654 |
self.maybe_free_model_hooks()
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 625 |
if torch.backends.mps.is_available():
|
| 626 |
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 627 |
latents = latents.to(latents_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 628 |
|
| 629 |
+
# call the callback, if provided
|
| 630 |
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 631 |
progress_bar.update()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 632 |
|
| 633 |
+
# Final image
|
| 634 |
+
return self._decode_latents_to_image(latents, height, width, output_type)
|
| 635 |
self.maybe_free_model_hooks()
|
| 636 |
+
torch.cuda.empty_cache()
|
| 637 |
+
|
| 638 |
+
def _decode_latents_to_image(self, latents, height, width, output_type, vae=None):
|
| 639 |
+
"""Decodes the given latents into an image."""
|
| 640 |
+
vae = vae or self.vae
|
| 641 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
| 642 |
+
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
|
| 643 |
+
image = vae.decode(latents, return_dict=False)[0]
|
| 644 |
+
return self.image_processor.postprocess(image, output_type=output_type)[0]
|