from functools import partial

import torch
from diffusers import DDPMScheduler
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import (
    retrieve_timesteps,
    retrieve_latents,
)

from model.pipeline_sdxl import StableDiffusionXLImg2ImgPipeline

SAMPLING_DEVICE = "cpu"  # device noise is drawn on; set to "cuda" to sample on GPU
VAE_SAMPLE = "argmax"  # "argmax" (deterministic latent mode) or "sample" (stochastic)
RESIZE_TYPE = None  # PIL resampling filter for inputs, e.g. Image.LANCZOS
device = "cuda" if torch.cuda.is_available() else "cpu"

def encode_image(image, pipe, generator):
    """Encode a PIL image into scaled SDXL VAE latents."""
    pipe_dtype = pipe.dtype
    image = pipe.image_processor.preprocess(image)
    image = image.to(device=device, dtype=pipe.dtype)

    # The SDXL VAE is numerically unstable in fp16, so upcast to fp32 for encoding.
    if pipe.vae.config.force_upcast:
        image = image.float()
        pipe.vae.to(dtype=torch.float32)

    init_latents = retrieve_latents(
        pipe.vae.encode(image), generator=generator, sample_mode=VAE_SAMPLE
    )

    # Restore the original dtype after the upcast.
    if pipe.vae.config.force_upcast:
        pipe.vae.to(pipe_dtype)
        init_latents = init_latents.to(pipe_dtype)

    return pipe.vae.config.scaling_factor * init_latents
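
# Example (illustrative; the image path is hypothetical):
#
#   from PIL import Image
#   gen = torch.Generator(device=SAMPLING_DEVICE).manual_seed(0)
#   x_0 = encode_image(Image.open("input.png").convert("RGB"), pipe, gen)
#
# With VAE_SAMPLE = "argmax", retrieve_latents returns the mode of the latent
# distribution, so encoding the same image twice yields identical latents.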

def create_xts(
    noise_shift_delta,
    noise_timesteps,
    generator,
    scheduler,
    timesteps,
    x_0,
):
    """Noise x_0 to a (shifted) timestep per denoising step; append x_0 itself."""
    if noise_timesteps is None:
        # Shift every timestep down by a fraction of the step size.
        noising_delta = noise_shift_delta * (timesteps[0] - timesteps[1])
        noise_timesteps = [timestep - int(noising_delta) for timestep in timesteps]

    # Timesteps shifted to zero or below get the clean latent x_0 instead of noise.
    first_x_0_idx = len(noise_timesteps)
    for i in range(len(noise_timesteps)):
        if noise_timesteps[i] <= 0:
            first_x_0_idx = i
            break
    noise_timesteps = noise_timesteps[:first_x_0_idx]

    # Sample one noise map per remaining timestep on SAMPLING_DEVICE for
    # reproducibility, then move the result next to x_0.
    x_0_expanded = x_0.expand(len(noise_timesteps), -1, -1, -1)
    noise = torch.randn(
        x_0_expanded.size(), generator=generator, device=SAMPLING_DEVICE
    ).to(x_0.device)
    x_ts = scheduler.add_noise(
        x_0_expanded,
        noise,
        torch.IntTensor(noise_timesteps),
    )
    x_ts = [t.unsqueeze(dim=0) for t in list(x_ts)]
    x_ts += [x_0] * (len(timesteps) - first_x_0_idx)
    x_ts += [x_0]
    return x_ts
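
# Worked example (illustrative numbers): with timesteps = [799, 599, 399, 199]
# and noise_shift_delta = 1.0, noising_delta = 1.0 * (799 - 599) = 200, so
# noise_timesteps = [599, 399, 199, -1]. The -1 entry truncates the list at
# first_x_0_idx = 3, three noised latents are produced, and the result is
# [x_599, x_399, x_199, x_0, x_0] -- one latent per timestep plus a trailing x_0.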

def load_pipeline(fp16, cache_dir):
    """Load the SDXL-Turbo img2img pipeline with the custom UNet and a DDPM scheduler."""
    kwargs = (
        {
            "torch_dtype": torch.float16,
            "variant": "fp16",
        }
        if fp16
        else {}
    )
    from model.unet_sdxl import OursUNet2DConditionModel

    unet = OursUNet2DConditionModel.from_pretrained(
        "stabilityai/sdxl-turbo",
        subfolder="unet",
        cache_dir=cache_dir,
        safety_checker=None,
        **kwargs,
    )
    pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        "stabilityai/sdxl-turbo",
        unet=unet,
        cache_dir=cache_dir,
        safety_checker=None,
        **kwargs,
    )
    pipeline = pipeline.to(device)
    # Swap in a DDPM scheduler so scheduler.add_noise in create_xts follows the
    # DDPM forward process rather than the checkpoint's default scheduler.
    pipeline.scheduler = DDPMScheduler.from_pretrained(  # type: ignore
        "stabilityai/sdxl-turbo",
        subfolder="scheduler",
    )
    return pipeline
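
# Example (illustrative): fp16 weights on a CUDA machine, default HF cache:
#
#   pipe = load_pipeline(fp16=True, cache_dir=None)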

def set_pipeline(
    pipeline: StableDiffusionXLImg2ImgPipeline, num_timesteps, generator, config
):
    """Bind inversion settings to the pipeline and return the timestep schedule."""
    if config.timesteps is None:
        # Derive the schedule from the step count, skipping the first
        # `step_start` steps.
        denoising_start = config.step_start / config.num_steps_inversion
        timesteps, num_inference_steps = retrieve_timesteps(
            pipeline.scheduler, config.num_steps_inversion, device, None
        )
        timesteps, num_inference_steps = pipeline.get_timesteps(
            num_inference_steps=num_inference_steps,
            device=device,
            denoising_start=denoising_start,
            strength=0,
        )
        timesteps = timesteps.type(torch.int64)
        # Freeze the call arguments. Note the partial must be invoked as
        # pipeline.__call__(...): plain pipeline(...) resolves __call__ on the
        # class and would bypass this instance attribute.
        pipeline.__call__ = partial(
            pipeline.__call__,
            num_inference_steps=config.num_steps_inversion,
            guidance_scale=config.guidance_scale,
            generator=generator,
            denoising_start=denoising_start,
            strength=0,
        )
        pipeline.scheduler.set_timesteps(
            timesteps=timesteps.cpu(),
        )
    else:
        # An explicit timestep list overrides the derived schedule.
        timesteps = torch.tensor(config.timesteps, dtype=torch.int64)
        pipeline.__call__ = partial(
            pipeline.__call__,
            timesteps=timesteps,
            guidance_scale=config.guidance_scale,
            denoising_start=0,
            strength=1,
        )
        pipeline.scheduler.set_timesteps(
            timesteps=config.timesteps,
        )
    timesteps = [torch.tensor(t) for t in timesteps.tolist()]
    return timesteps, config
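
# --- Usage sketch (illustrative; not part of the original file). The config
# fields mirror those read by set_pipeline and create_xts above; the seed,
# input path, and field values are assumptions.
if __name__ == "__main__":
    from types import SimpleNamespace

    from PIL import Image

    config = SimpleNamespace(
        num_steps_inversion=4,  # assumed few-step setting for SDXL-Turbo
        step_start=1,
        guidance_scale=0.0,
        timesteps=None,
        noise_shift_delta=1.0,
    )
    generator = torch.Generator(device=SAMPLING_DEVICE).manual_seed(0)
    pipe = load_pipeline(fp16=(device == "cuda"), cache_dir=None)
    timesteps, config = set_pipeline(
        pipe, config.num_steps_inversion, generator, config
    )

    image = Image.open("input.png").convert("RGB")
    if RESIZE_TYPE is not None:
        image = image.resize((512, 512), RESIZE_TYPE)

    x_0 = encode_image(image, pipe, generator)
    x_ts = create_xts(
        config.noise_shift_delta, None, generator, pipe.scheduler, timesteps, x_0
    )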