import ipdb  # noqa: F401
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

from diffusionsfm.utils.visualization import plot_to_image


class NoiseScheduler(nn.Module):
    """Precomputed diffusion noise schedule stored as module buffers.

    Registers three buffers of length ``max_timesteps``:
    ``betas``, ``alphas`` (``1 - betas``) and ``alphas_cumprod``
    (cumulative product of ``alphas``). It also precomputes
    ``inference_timesteps``, a descending integer array of
    training-timestep indices.

    Args:
        max_timesteps: Number of training diffusion steps T.
        beta_start: First beta of the linear / scaled-linear schedules.
        beta_end: Last beta of the linear / scaled-linear schedules.
        cos_power: Exponent applied to the cosine in the cosine schedule.
        num_inference_steps: Default number of sampling steps.
        type: One of ``"linear"``, ``"cosine"`` or ``"scaled_linear"``.

    Raises:
        ValueError: If ``type`` is not one of the supported schedules.
    """

    def __init__(
        self,
        max_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        cos_power=2,
        num_inference_steps=100,
        type="linear",
    ):
        super().__init__()
        self.max_timesteps = max_timesteps
        self.num_inference_steps = num_inference_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.cos_power = cos_power
        self.type = type

        if type == "linear":
            self.register_linear_schedule()
        elif type == "cosine":
            self.register_cosine_schedule(cos_power)
        elif type == "scaled_linear":
            self.register_scaled_linear_schedule()
        else:
            # Without this, an unknown type leaves the module with no
            # schedule buffers and it fails much later, far from the cause.
            raise ValueError(
                f"Unknown noise schedule type {type!r}; expected "
                "'linear', 'cosine' or 'scaled_linear'."
            )

        self.inference_timesteps = self.compute_inference_timesteps()

    def register_linear_schedule(self):
        """Register linearly spaced betas rescaled to zero terminal SNR.

        Follows https://arxiv.org/pdf/2305.08891: sqrt(alpha_bar) is
        shifted so its last entry becomes 0 and scaled so its first entry is
        unchanged. Per-step alphas are then recovered from consecutive
        ratios of alpha_bar, which makes the final beta equal to 1.
        """
        betas = torch.linspace(
            self.beta_start,
            self.beta_end,
            self.max_timesteps,
            dtype=torch.float32,
        )
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)

        # Zero-terminal-SNR rescaling of sqrt(alpha_bar).
        alphas_bar_sqrt = alphas_cumprod.sqrt()
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
        alphas_bar_sqrt -= alphas_bar_sqrt_T
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

        # Convert the rescaled cumulative product back into per-step values.
        alphas_bar = alphas_bar_sqrt**2
        alphas = alphas_bar[1:] / alphas_bar[:-1]
        alphas = torch.cat([alphas_bar[0:1], alphas])
        betas = 1 - alphas

        self.register_buffer("betas", betas)
        self.register_buffer("alphas", 1.0 - self.betas)
        self.register_buffer("alphas_cumprod", torch.cumprod(self.alphas, dim=0))

    def register_cosine_schedule(self, cos_power, s=0.008):
        """Register a cosine schedule with a configurable exponent.

        alpha_bar(t) = cos(((t / T) + s) / (1 + s) * pi / 2) ** cos_power,
        normalized so alpha_bar(0) == 1. Betas come from ratios of
        consecutive alpha_bar values and are clamped to [0, 0.999].

        Args:
            cos_power: Exponent applied to the cosine.
            s: Small offset added to t / T before the cosine.
        """
        timesteps = (
            torch.arange(self.max_timesteps + 1, dtype=torch.float32)
            / self.max_timesteps
        )
        alpha_bars = (timesteps + s) / (1 + s) * np.pi / 2
        alpha_bars = torch.cos(alpha_bars).pow(cos_power)
        alpha_bars = alpha_bars / alpha_bars[0]
        betas = 1 - alpha_bars[1:] / alpha_bars[:-1]
        # Use torch.clamp rather than np.clip: np.clip on a torch tensor relies
        # on numpy's fallback dispatch for foreign array types, while
        # register_buffer requires a real torch.Tensor.
        betas = torch.clamp(betas, min=0, max=0.999)

        self.register_buffer("betas", betas)
        self.register_buffer("alphas", 1.0 - betas)
        self.register_buffer("alphas_cumprod", torch.cumprod(self.alphas, dim=0))

    def register_scaled_linear_schedule(self):
        """Register betas whose square roots are linearly spaced.

        sqrt(beta) goes linearly from sqrt(beta_start) to sqrt(beta_end)
        over ``max_timesteps`` steps; the result is squared.
        """
        self.register_buffer(
            "betas",
            torch.linspace(
                self.beta_start**0.5,
                self.beta_end**0.5,
                self.max_timesteps,
                dtype=torch.float32,
            )
            ** 2,
        )
        self.register_buffer("alphas", 1.0 - self.betas)
        self.register_buffer("alphas_cumprod", torch.cumprod(self.alphas, dim=0))

    def compute_inference_timesteps(
        self, num_inference_steps=None, num_train_steps=None
    ):
        """Return evenly strided training timesteps in descending order.

        Based on diffusers's scheduling code.

        Args:
            num_inference_steps: Number of sampling steps; defaults to
                ``self.num_inference_steps``.
            num_train_steps: Number of training steps; defaults to
                ``self.max_timesteps``.

        Returns:
            An integer numpy array
            ``[(n - 1) * r, ..., r, 0]``, where ``r = num_train_steps //
            num_inference_steps``.
        """
        if num_inference_steps is None:
            num_inference_steps = self.num_inference_steps
        if num_train_steps is None:
            num_train_steps = self.max_timesteps

        # NOTE: integer division; if num_inference_steps > num_train_steps the
        # stride becomes 0 and every returned timestep is 0.
        step_ratio = num_train_steps // num_inference_steps
        timesteps = (
            (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(int)
        )
        return timesteps

    def plot_schedule(self, return_image=False):
        """Plot sqrt(alphas_cumprod) against the training timestep.

        Args:
            return_image: If True, convert the figure with ``plot_to_image``,
                close it and return the result. Otherwise the figure is left
                open and None is returned.
        """
        fig = plt.figure(figsize=(6, 4), dpi=100)
        alpha_bars = self.alphas_cumprod.cpu().numpy()
        plt.plot(np.sqrt(alpha_bars))
        plt.grid()
        # The original used `else: self.type == "cosine"` (a no-op
        # expression), so every non-linear schedule was titled "Cosine".
        if self.type == "linear":
            plt.title(
                f"Linear (T={self.max_timesteps}, S={self.beta_start}, E={self.beta_end})"
            )
        elif self.type == "scaled_linear":
            plt.title(
                f"Scaled Linear (T={self.max_timesteps}, S={self.beta_start}, E={self.beta_end})"
            )
        elif self.type == "cosine":
            plt.title(f"Cosine (T={self.max_timesteps}, P={self.cos_power})")
        if return_image:
            image = plot_to_image(fig)
            plt.close(fig)
            return image