from typing import Any, Dict

import torch
import tqdm
from diffusers.pipelines.stable_diffusion_3 import pipeline_stable_diffusion_3

from src.flair.pipelines import utils

class SD3Wrapper(pipeline_stable_diffusion_3.StableDiffusion3Pipeline):

    def to(self, device, *args, **kwargs):
        # Only move the denoising transformer and the VAE; the text encoders
        # are deliberately not moved here.
        self.transformer.to(device)
        self.vae.to(device)
        return self

    def get_timesteps(self, n_steps, device, ts_min=0):
        # Linear timestep schedule from 1 down to ts_min, excluding both endpoints.
        timesteps = torch.linspace(1, ts_min, n_steps + 2, device=device, dtype=torch.float32)
        return timesteps[1:-1]
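
    # Illustrative example (not from the original file): get_timesteps(3, "cpu")
    # builds torch.linspace(1, 0, 5) = [1.00, 0.75, 0.50, 0.25, 0.00] and
    # returns the interior points tensor([0.7500, 0.5000, 0.2500]).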

    def single_step(
        self,
        img_latent: torch.Tensor,
        t: torch.Tensor,
        kwargs: Dict[str, Any],
        is_noised_latent: bool = False,
    ):
if "noise" in kwargs: | |
noise = kwargs["noise"].detach() | |
alpha = kwargs["inv_alpha"] | |
if alpha == "tsqrt": | |
alpha = t**0.5 # * 0.75 | |
elif alpha == "t": | |
alpha = t | |
elif alpha == "sine": | |
alpha = torch.sin(t * 3.141592653589793/2) | |
elif alpha == "1-t": | |
alpha = 1 - t | |
elif alpha == "1-t*0.5": | |
alpha = (1 - t)*0.5 | |
elif alpha == "1-t*0.9": | |
alpha = (1 - t)*0.9 | |
elif alpha == "t**1/3": | |
alpha = t**(1/3) | |
elif alpha == "(1-t)**0.5": | |
alpha = (1-t)**0.5 | |
elif alpha == "((1-t)*0.8)**0.5": | |
alpha = (1-t*0.8)**0.5 | |
elif alpha == "(1-t)**2": | |
alpha = (1-t)**2 | |
# alpha = t * kwargs["inv_alpha"] | |
noise = (alpha) * noise + (1-alpha**2)**0.5 * torch.randn_like(img_latent) | |
# noise = noise / noise.std() | |
# noise = noise / (1- 2*alpha*(1-alpha))**0.5 | |
# noise = noise + alpha * torch.randn_like(img_latent) | |
        else:
            noise = torch.randn_like(img_latent)
        if is_noised_latent:
            noised_latent = img_latent
        else:
            # Rectified-flow interpolation towards noise: x_t = t * eps + (1 - t) * x_0.
            noised_latent = t * noise + (1 - t) * img_latent

        latent_model_input = torch.cat([noised_latent] * 2) if self.do_classifier_free_guidance else noised_latent
        # Broadcast to batch dimension in a way that's compatible with ONNX/Core ML.
        timestep = t.expand(latent_model_input.shape[0])
        noise_pred = self.transformer(
            hidden_states=latent_model_input.to(img_latent.dtype),
            timestep=(timestep * 1000).to(img_latent.dtype),  # SD3 expects timesteps in [0, 1000]
            encoder_hidden_states=kwargs["prompt_embeds"].repeat(img_latent.shape[0], 1, 1),
            pooled_projections=kwargs["pooled_prompt_embeds"].repeat(img_latent.shape[0], 1),
            joint_attention_kwargs=None,
            return_dict=False,
        )[0]

        if self.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Convert the predicted velocity into an epsilon (noise) prediction.
        eps = utils.v_to_eps(noise_pred, t, noised_latent)
        return eps, noise, (1 - t), t, noise_pred
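
    # utils.v_to_eps is defined elsewhere in this repo. A minimal sketch of a
    # consistent conversion, assuming the rectified-flow parameterization used
    # above (x_t = t * eps + (1 - t) * x_0, so the velocity is v = eps - x_0):
    #
    #     def v_to_eps(v, t, x_t):
    #         # substituting x_0 = eps - v into x_t gives x_t = eps - (1 - t) * v,
    #         # hence eps = x_t + (1 - t) * v
    #         return x_t + (1 - t) * v
    #
    # This illustrates the algebra only; it is not the repo's actual helper.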

    def encode(self, img):
        # Encode the image into latent space and apply the VAE's shift/scale.
        img_latent = self.vae.encode(img, return_dict=False)[0]
        if hasattr(img_latent, "sample"):
            img_latent = img_latent.sample()
        img_latent = (img_latent - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        return img_latent

    def decode(self, img_latent):
        # Invert the shift/scale applied in encode(), then decode to image space.
        img = self.vae.decode(img_latent / self.vae.config.scaling_factor + self.vae.config.shift_factor, return_dict=False)[0]
        return img
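
    # Hedged round-trip sketch (assumes `pipe` is an initialized SD3Wrapper and
    # `img` is a (B, 3, H, W) tensor scaled to the VAE's expected [-1, 1] range):
    #
    #     latent = pipe.encode(img)
    #     recon = pipe.decode(latent)  # approximate reconstruction of img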

    def denoise(self, pseudo_inv, kwargs, inverse=False):
        # Timesteps run from 1 (pure noise) down to 0 (clean image); inversion
        # traverses the same schedule in the opposite direction.
        timesteps = torch.linspace(1, 0, kwargs["n_steps"], device=pseudo_inv.device, dtype=pseudo_inv.dtype)
        sigmas = timesteps
        if inverse:
            timesteps = timesteps.flip(0)
            sigmas = sigmas.flip(0)

        for i, t in tqdm.tqdm(enumerate(timesteps[:-1]), desc="Denoising", total=len(timesteps) - 1):
            # t stays in [0, 1]; single_step rescales it to the transformer's
            # [0, 1000] range internally.
            eps, noise, _, t, v = self.single_step(
                pseudo_inv,
                t,
                kwargs,
                is_noised_latent=True,
            )
            # Explicit Euler step along the predicted velocity field.
            sigma_next = sigmas[i + 1]
            sigma_t = sigmas[i]
            pseudo_inv = pseudo_inv + v * (sigma_next - sigma_t)
        return pseudo_inv
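

# Hedged end-to-end usage sketch (not part of the original file). Assumes a
# standard SD3 checkpoint and the usual diffusers `encode_prompt` API; the
# checkpoint id, guidance scale, and latent shape below are placeholders.
if __name__ == "__main__":
    pipe = SD3Wrapper.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
    ).to("cuda")
    pipe._guidance_scale = 7.0  # normally set inside __call__; assumed backing field

    # The overridden .to() above only moves the transformer and VAE, so move
    # the text encoders explicitly before encoding the prompt.
    for enc in (pipe.text_encoder, pipe.text_encoder_2, pipe.text_encoder_3):
        enc.to("cuda")

    prompt_embeds, neg_embeds, pooled, neg_pooled = pipe.encode_prompt(
        prompt="a photo of a cat", prompt_2=None, prompt_3=None, device="cuda"
    )
    kwargs = {
        # With classifier-free guidance the wrapper expects [negative; positive]
        # embeddings concatenated along the batch dimension.
        "prompt_embeds": torch.cat([neg_embeds, prompt_embeds]),
        "pooled_prompt_embeds": torch.cat([neg_pooled, pooled]),
        "n_steps": 28,
    }
    latent = torch.randn(1, 16, 128, 128, device="cuda", dtype=torch.float16)
    image_latent = pipe.denoise(latent, kwargs)
    image = pipe.decode(image_latent)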