# qitaoz's picture
# Upload 57 files
# 4562a06 verified
import ipdb # noqa: F401
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from diffusionsfm.utils.visualization import plot_to_image
class NoiseScheduler(nn.Module):
    """Diffusion noise schedule stored as module buffers.

    Depending on ``type`` the constructor builds one of three beta schedules
    and registers three 1-D buffers of length ``max_timesteps``:

    * ``betas``
    * ``alphas`` (``1 - betas``)
    * ``alphas_cumprod`` (cumulative product of ``alphas``)

    It also precomputes ``inference_timesteps``, a descending NumPy array of
    integer timesteps used for sampling.

    Args:
        max_timesteps: Number of training timesteps (length of the schedule).
        beta_start: First beta for the "linear" / "scaled_linear" schedules.
        beta_end: Last beta for the "linear" / "scaled_linear" schedules.
        cos_power: Exponent applied to the cosine in the "cosine" schedule.
        num_inference_steps: Default number of sampling steps.
        type: One of ``"linear"``, ``"cosine"`` or ``"scaled_linear"``.

    Raises:
        ValueError: If ``type`` is not one of the supported schedule names.
    """

    def __init__(
        self,
        max_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        cos_power=2,
        num_inference_steps=100,
        type="linear",
    ):
        super().__init__()
        self.max_timesteps = max_timesteps
        self.num_inference_steps = num_inference_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.cos_power = cos_power
        # NOTE(review): this instance attribute shadows ``nn.Module.type``;
        # kept because ``plot_schedule`` (and possibly callers) read it.
        self.type = type

        if type == "linear":
            self.register_linear_schedule()
        elif type == "cosine":
            self.register_cosine_schedule(cos_power)
        elif type == "scaled_linear":
            self.register_scaled_linear_schedule()
        else:
            # Fail fast: otherwise no buffers are registered and the first
            # access to ``betas``/``alphas_cumprod`` fails far from the cause.
            raise ValueError(
                f"Unknown schedule type: {type!r}. "
                "Expected 'linear', 'cosine' or 'scaled_linear'."
            )

        self.inference_timesteps = self.compute_inference_timesteps()

    def _register_schedule(self, betas):
        """Register ``betas``, ``alphas`` and ``alphas_cumprod`` from ``betas``."""
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", 1.0 - self.betas)
        self.register_buffer("alphas_cumprod", torch.cumprod(self.alphas, dim=0))

    def register_linear_schedule(self):
        """Register a linear beta schedule rescaled to zero terminal SNR.

        Betas are first spaced linearly from ``beta_start`` to ``beta_end``;
        sqrt(alpha_bar) is then shifted so its last entry is zero and scaled
        so its first entry is unchanged
        (https://arxiv.org/pdf/2305.08891). Betas are recovered from the
        rescaled alpha_bar.
        """
        linear_betas = torch.linspace(
            self.beta_start,
            self.beta_end,
            self.max_timesteps,
            dtype=torch.float32,
        )
        sqrt_alpha_bar = torch.cumprod(1.0 - linear_betas, dim=0).sqrt()

        first = sqrt_alpha_bar[0].clone()
        last = sqrt_alpha_bar[-1].clone()
        # Shift so the final value is 0, then scale so the first is preserved.
        sqrt_alpha_bar = (sqrt_alpha_bar - last) * (first / (first - last))

        alpha_bar = sqrt_alpha_bar**2
        # Undo the cumulative product: alpha_t = alpha_bar_t / alpha_bar_{t-1}.
        alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])
        self._register_schedule(1 - alphas)

    def register_cosine_schedule(self, cos_power, s=0.008):
        """Register a cosine-based beta schedule.

        alpha_bar is ``cos(((t / T) + s) / (1 + s) * pi / 2) ** cos_power``
        normalised by its value at ``t = 0``; betas are derived from
        consecutive ratios and clamped to ``[0, 0.999]``.

        Args:
            cos_power: Exponent applied to the cosine.
            s: Offset added to the normalised timestep.
        """
        fractions = (
            torch.arange(self.max_timesteps + 1, dtype=torch.float32)
            / self.max_timesteps
        )
        angles = (fractions + s) / (1 + s) * np.pi / 2
        alpha_bars = torch.cos(angles).pow(cos_power)
        alpha_bars = alpha_bars / alpha_bars[0]
        betas = 1 - alpha_bars[1:] / alpha_bars[:-1]
        # Clamp with torch (not NumPy) so the result is guaranteed to remain
        # a tensor suitable for ``register_buffer``.
        betas = torch.clamp(betas, min=0, max=0.999)
        self._register_schedule(betas)

    def register_scaled_linear_schedule(self):
        """Register betas that are linear in sqrt-space, then squared."""
        sqrt_betas = torch.linspace(
            self.beta_start**0.5,
            self.beta_end**0.5,
            self.max_timesteps,
            dtype=torch.float32,
        )
        self._register_schedule(sqrt_betas**2)

    def compute_inference_timesteps(
        self, num_inference_steps=None, num_train_steps=None
    ):
        """Return evenly strided timesteps in descending order.

        Based on diffusers's scheduling code.

        Args:
            num_inference_steps: Number of timesteps to return; defaults to
                ``self.num_inference_steps``.
            num_train_steps: Length of the training schedule; defaults to
                ``self.max_timesteps``.

        Returns:
            Integer NumPy array of length ``num_inference_steps`` containing
            ``i * (num_train_steps // num_inference_steps)`` for each ``i``,
            largest first (so the last entry is 0).
        """
        if num_inference_steps is None:
            num_inference_steps = self.num_inference_steps
        if num_train_steps is None:
            num_train_steps = self.max_timesteps

        step_ratio = num_train_steps // num_inference_steps
        ascending = (np.arange(0, num_inference_steps) * step_ratio).round()
        return ascending[::-1].astype(int)

    def plot_schedule(self, return_image=False):
        """Plot sqrt(alphas_cumprod) against the timestep index.

        Args:
            return_image: If true, render the figure with ``plot_to_image``,
                close it, and return the result. Otherwise the figure is left
                open on the current pyplot state and ``None`` is returned.
        """
        fig = plt.figure(figsize=(6, 4), dpi=100)
        alpha_bars = self.alphas_cumprod.cpu().numpy()
        plt.plot(np.sqrt(alpha_bars))
        plt.grid()

        if self.type == "linear":
            plt.title(
                f"Linear (T={self.max_timesteps}, S={self.beta_start}, E={self.beta_end})"
            )
        elif self.type == "cosine":
            plt.title(f"Cosine (T={self.max_timesteps}, P={self.cos_power})")
        elif self.type == "scaled_linear":
            # Previously this schedule was mislabelled as "Cosine".
            plt.title(
                f"Scaled Linear (T={self.max_timesteps}, S={self.beta_start}, E={self.beta_end})"
            )

        if return_image:
            image = plot_to_image(fig)
            plt.close(fig)
            return image