import torch
from typing import Dict, Any
from diffusers.pipelines.stable_diffusion_3 import pipeline_stable_diffusion_3
from src.flair.pipelines import utils
import tqdm


class SD3Wrapper(pipeline_stable_diffusion_3.StableDiffusion3Pipeline):
    def to(self, device, *args, **kwargs):
        # Move only the transformer and VAE; other components are left in place.
        self.transformer.to(device)
        self.vae.to(device)
        return self

    def get_timesteps(self, n_steps, device, ts_min=0):
        # Linear schedule from t = 1 (pure noise) down to ts_min, excluding
        # both endpoints.
        timesteps = torch.linspace(1, ts_min, n_steps + 2, device=device, dtype=torch.float32)
        return timesteps[1:-1]
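
    # Worked example (editor's note): get_timesteps(3, device) builds
    # torch.linspace(1, 0, 5) = [1.00, 0.75, 0.50, 0.25, 0.00] and returns
    # the interior points [0.75, 0.50, 0.25].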

    def single_step(
        self,
        img_latent: torch.Tensor,
        t: torch.Tensor,
        kwargs: Dict[str, Any],
        is_noised_latent: bool = False,
    ):
        if "noise" in kwargs:
            # Mix a fixed reference noise with fresh Gaussian noise; the
            # "inv_alpha" key selects how the mixing weight depends on t.
            noise = kwargs["noise"].detach()
            alpha = kwargs["inv_alpha"]
            if alpha == "tsqrt":
                alpha = t ** 0.5
            elif alpha == "t":
                alpha = t
            elif alpha == "sine":
                alpha = torch.sin(t * torch.pi / 2)
            elif alpha == "1-t":
                alpha = 1 - t
            elif alpha == "1-t*0.5":
                alpha = (1 - t) * 0.5
            elif alpha == "1-t*0.9":
                alpha = (1 - t) * 0.9
            elif alpha == "t**1/3":
                alpha = t ** (1 / 3)
            elif alpha == "(1-t)**0.5":
                alpha = (1 - t) ** 0.5
            elif alpha == "((1-t)*0.8)**0.5":
                alpha = ((1 - t) * 0.8) ** 0.5
            elif alpha == "(1-t)**2":
                alpha = (1 - t) ** 2
            # Variance-preserving mixture: alpha**2 + (1 - alpha**2) = 1.
            noise = alpha * noise + (1 - alpha ** 2) ** 0.5 * torch.randn_like(img_latent)
        else:
            noise = torch.randn_like(img_latent)

        if is_noised_latent:
            noised_latent = img_latent
        else:
            # Rectified-flow forward process: z_t = t * eps + (1 - t) * x_0.
            noised_latent = t * noise + (1 - t) * img_latent

        latent_model_input = torch.cat([noised_latent] * 2) if self.do_classifier_free_guidance else noised_latent
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timestep = t.expand(latent_model_input.shape[0])
        noise_pred = self.transformer(
            hidden_states=latent_model_input.to(img_latent.dtype),
            timestep=(timestep * 1000).to(img_latent.dtype),  # SD3 expects timesteps scaled to [0, 1000]
            encoder_hidden_states=kwargs["prompt_embeds"].repeat(img_latent.shape[0], 1, 1),
            pooled_projections=kwargs["pooled_prompt_embeds"].repeat(img_latent.shape[0], 1),
            joint_attention_kwargs=None,
            return_dict=False,
        )[0]

        if self.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
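
        # Editor's note on the conversion below (assuming utils.v_to_eps
        # implements the standard rectified-flow identity; not taken from
        # utils itself): with z_t = (1 - t) * x_0 + t * eps and v = eps - x_0,
        #     z_t + (1 - t) * v
        #         = (1 - t) * x_0 + t * eps + (1 - t) * (eps - x_0)
        #         = eps,
        # so eps can be recovered as noised_latent + (1 - t) * noise_pred.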
        eps = utils.v_to_eps(noise_pred, t, noised_latent)
        return eps, noise, (1 - t), t, noise_pred

    def encode(self, img):
        # Encode the image into latent space and apply the VAE's shift/scale
        # normalization.
        img_latent = self.vae.encode(img, return_dict=False)[0]
        if hasattr(img_latent, "sample"):
            img_latent = img_latent.sample()
        img_latent = (img_latent - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        return img_latent
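
    # For reference (editor's note; values assumed from the released SD3 VAE
    # config, not read from this repo): scaling_factor ~ 1.5305 and
    # shift_factor ~ 0.0609, so encode() brings raw VAE latents to roughly
    # unit variance.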

    def decode(self, img_latent):
        # Decode the latent representation back to image space, undoing the
        # scale/shift applied in encode().
        img = self.vae.decode(img_latent / self.vae.config.scaling_factor + self.vae.config.shift_factor, return_dict=False)[0]
        return img

    def denoise(self, pseudo_inv, kwargs, inverse=False):
        # Linear sigma/timestep schedule from 1 to 0 (flipped when inverting).
        timesteps = torch.linspace(1, 0, kwargs["n_steps"], device=pseudo_inv.device, dtype=pseudo_inv.dtype)
        sigmas = timesteps
        if inverse:
            timesteps = timesteps.flip(0)
            sigmas = sigmas.flip(0)
        # Euler-integrate the predicted velocity field, one step per timestep.
        for i, t in tqdm.tqdm(enumerate(timesteps[:-1]), desc="Denoising", total=len(timesteps) - 1):
            # single_step already rescales t by 1000 for the transformer, so
            # pass t in [0, 1] here.
            eps, noise, _, t, v = self.single_step(
                pseudo_inv,
                t.to(pseudo_inv.device),
                kwargs,
                is_noised_latent=True,
            )
            # Euler step: z_{sigma_next} = z_{sigma_t} + v * (sigma_next - sigma_t).
            sigma_next = sigmas[i + 1]
            sigma_t = sigmas[i]
            pseudo_inv = pseudo_inv + v * (sigma_next - sigma_t)
        return pseudo_inv
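

if __name__ == "__main__":
    # Minimal usage sketch (an editor's illustration under assumed settings,
    # not part of the original pipeline). The checkpoint id, prompt, latent
    # shape, and step count below are placeholders; encode_prompt is the
    # standard diffusers SD3 API.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = SD3Wrapper.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",
        torch_dtype=torch.float16,
    )
    # The custom to() above only moves the transformer and VAE, so move the
    # text encoders explicitly before prompting.
    pipe.text_encoder.to(device)
    pipe.text_encoder_2.to(device)
    pipe.text_encoder_3.to(device)
    pipe.to(device)

    prompt_embeds, _, pooled_prompt_embeds, _ = pipe.encode_prompt(
        prompt="a photo of a cat",
        prompt_2=None,
        prompt_3=None,
        device=device,
        do_classifier_free_guidance=False,
    )
    # do_classifier_free_guidance reads _guidance_scale, which the stock
    # __call__ normally sets; set it by hand since denoise() bypasses __call__.
    pipe._guidance_scale = 1.0

    latents = torch.randn(1, 16, 128, 128, device=device, dtype=torch.float16)
    step_kwargs = {
        "n_steps": 28,
        "prompt_embeds": prompt_embeds,
        "pooled_prompt_embeds": pooled_prompt_embeds,
    }
    image = pipe.decode(pipe.denoise(latents, step_kwargs))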