update for image batch processing
Browse files
- pipeline.py +20 -5
pipeline.py
CHANGED
|
@@ -856,7 +856,8 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
|
|
| 856 |
)
|
| 857 |
|
| 858 |
# 4. Prepare image, and controlnet_conditioning_image
|
| 859 |
-
image = prepare_image(image)
|
|
|
|
| 860 |
|
| 861 |
# condition image(s)
|
| 862 |
if isinstance(self.controlnet, ControlNetModel):
|
|
@@ -897,15 +898,27 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
|
|
| 897 |
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
| 898 |
|
| 899 |
# 6. Prepare latent variables
|
| 900 |
-
latents = self.prepare_latents(
|
| 901 |
-
image,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 902 |
latent_timestep,
|
| 903 |
batch_size,
|
| 904 |
num_images_per_prompt,
|
| 905 |
prompt_embeds.dtype,
|
| 906 |
device,
|
| 907 |
generator,
|
| 908 |
-
)
|
|
|
|
|
|
|
| 909 |
|
| 910 |
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 911 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
|
@@ -915,7 +928,9 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
|
|
| 915 |
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 916 |
for i, t in enumerate(timesteps):
|
| 917 |
# expand the latents if we are doing classifier free guidance
|
| 918 |
-
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
|
|
|
|
|
|
| 919 |
|
| 920 |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 921 |
|
|
|
|
| 856 |
)
|
| 857 |
|
| 858 |
# 4. Prepare image, and controlnet_conditioning_image
|
| 859 |
+
# image = prepare_image(image)
|
| 860 |
+
images = [prepare_image(img) for img in image]
|
| 861 |
|
| 862 |
# condition image(s)
|
| 863 |
if isinstance(self.controlnet, ControlNetModel):
|
|
|
|
| 898 |
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
| 899 |
|
| 900 |
# 6. Prepare latent variables
|
| 901 |
+
# latents = self.prepare_latents(
|
| 902 |
+
# image,
|
| 903 |
+
# latent_timestep,
|
| 904 |
+
# batch_size,
|
| 905 |
+
# num_images_per_prompt,
|
| 906 |
+
# prompt_embeds.dtype,
|
| 907 |
+
# device,
|
| 908 |
+
# generator,
|
| 909 |
+
# )
|
| 910 |
+
|
| 911 |
+
latents = [self.prepare_latents(
|
| 912 |
+
img,
|
| 913 |
latent_timestep,
|
| 914 |
batch_size,
|
| 915 |
num_images_per_prompt,
|
| 916 |
prompt_embeds.dtype,
|
| 917 |
device,
|
| 918 |
generator,
|
| 919 |
+
) for img in images]
|
| 920 |
+
latents = torch.cat(latents)
|
| 921 |
+
|
| 922 |
|
| 923 |
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 924 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
|
|
|
| 928 |
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 929 |
for i, t in enumerate(timesteps):
|
| 930 |
# expand the latents if we are doing classifier free guidance
|
| 931 |
+
# latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 932 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents.clone()
|
| 933 |
+
|
| 934 |
|
| 935 |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 936 |
|