""" Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py """ from collections import defaultdict from typing import Dict, Union import torch from omegaconf import ListConfig, OmegaConf from tqdm import tqdm from einops import rearrange from ...modules.diffusionmodules.sampling_utils import ( get_ancestral_step, linear_multistep_coeff, to_d, to_neg_log_sigma, to_sigma, chunk_inputs, ) from ...util import append_dims, default, instantiate_from_config DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} class BaseDiffusionSampler: def __init__( self, discretization_config: Union[Dict, ListConfig, OmegaConf], num_steps: Union[int, None] = None, guider_config: Union[Dict, ListConfig, OmegaConf, None] = None, verbose: bool = True, device: str = "cuda", ): self.num_steps = num_steps self.discretization = instantiate_from_config(discretization_config) self.guider = instantiate_from_config( default( guider_config, DEFAULT_GUIDER, ) ) self.verbose = verbose self.device = device def set_num_steps(self, num_steps: int): self.num_steps = num_steps def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None, strength=1.0): print("Num steps: ", self.num_steps if num_steps is None else num_steps) sigmas = self.discretization(self.num_steps if num_steps is None else num_steps, device=self.device) if strength != 1.0: init_timestep = min(int(len(sigmas) * strength), len(sigmas)) t_start = max(len(sigmas) - init_timestep, 0) # sigmas[:t_start] = torch.ones_like(sigmas[:t_start]) * sigmas[t_start] sigmas = sigmas[t_start:] uc = default(uc, cond) x *= torch.sqrt(1.0 + sigmas[0] ** 2.0) num_sigmas = len(sigmas) s_in = x.new_ones([x.shape[0]]) return x, s_in, sigmas, num_sigmas, cond, uc def denoise(self, x, denoiser, sigma, cond, uc): denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc)) denoised = self.guider(denoised, sigma) return denoised def get_sigma_gen(self, num_sigmas): sigma_generator = range(num_sigmas - 1) if self.verbose: print("#" * 30, " Sampling setting ", "#" * 30) print(f"Sampler: {self.__class__.__name__}") print(f"Discretization: {self.discretization.__class__.__name__}") print(f"Guider: {self.guider.__class__.__name__}") sigma_generator = tqdm( sigma_generator, total=num_sigmas, desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps", ) return sigma_generator class FIFODiffusionSampler(BaseDiffusionSampler): def __init__(self, lookahead=False, num_frames=14, num_partitions=4, *args, **kwargs): super().__init__(*args, **kwargs) self.num_frames = num_frames self.lookahead = lookahead self.num_partitions = num_partitions self.num_steps = self.num_frames * self.num_partitions self.fifo = [] def get_sigma_gen(self, num_sigmas, total_n_frames): total = total_n_frames + num_sigmas - self.num_frames sigma_generator = range(total_n_frames + num_sigmas - self.num_frames - 1) if self.verbose: print("#" * 30, " Sampling setting ", "#" * 30) print(f"Sampler: {self.__class__.__name__}") print(f"Discretization: {self.discretization.__class__.__name__}") print(f"Guider: {self.guider.__class__.__name__}") sigma_generator = tqdm( sigma_generator, total=total, desc=f"Sampling with {self.__class__.__name__} for {total} steps", ) return sigma_generator def prepare_sampling_loop(self, x, cond, uc=None): sigmas = self.discretization(self.num_steps, device=self.device) uc = default(uc, cond) x *= torch.sqrt(1.0 + sigmas[0] ** 2.0) num_sigmas = len(sigmas) s_in = x.new_ones([x.shape[0]]) return x, s_in, sigmas, num_sigmas, cond, uc class SingleStepDiffusionSampler(BaseDiffusionSampler): def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs): raise NotImplementedError def euler_step(self, x, d, dt): return x + dt * d class EDMSampler(SingleStepDiffusionSampler): def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs): super().__init__(*args, **kwargs) self.s_churn = s_churn self.s_tmin = s_tmin self.s_tmax = s_tmax self.s_noise = s_noise def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0): sigma_hat = sigma * (gamma + 1.0) if gamma > 0: eps = torch.randn_like(x) * self.s_noise x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 denoised = self.denoise(x, denoiser, sigma_hat, cond, uc) if x.ndim == 5: denoised = rearrange(denoised, "(b t) c h w -> b c t h w", b=x.shape[0]) d = to_d(x, sigma_hat, denoised) dt = append_dims(next_sigma - sigma_hat, x.ndim) euler_step = self.euler_step(x, d, dt) x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) return x def __call__(self, denoiser, x, cond, uc=None, num_steps=None, strength=1.0): x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps, strength=strength) for i in self.get_sigma_gen(num_sigmas): gamma = ( min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 ) x = self.sampler_step( s_in * sigmas[i], s_in * sigmas[i + 1], denoiser, x, cond, uc, gamma, ) return x class EDMSampleCFGplusplus(SingleStepDiffusionSampler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0): sigma_hat = sigma denoised, x_u = self.denoise(x, denoiser, sigma_hat, cond, uc) if x.ndim == 5: denoised = rearrange(denoised, "(b t) c h w -> b c t h w", b=x.shape[0]) x_u = rearrange(x_u, "(b t) c h w -> b c t h w", b=x.shape[0]) d = to_d(x, sigma_hat, x_u) dt = append_dims(next_sigma - sigma_hat, x.ndim) next_sigma = append_dims(next_sigma, x.ndim) euler_step = self.euler_step(denoised, d, next_sigma) x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) return x def __call__(self, denoiser, x, cond, uc=None, num_steps=None, strength=1.0): x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps, strength=strength) for i in self.get_sigma_gen(num_sigmas): s_in = x.new_ones([x.shape[0]]) x = self.sampler_step( s_in * sigmas[i], s_in * sigmas[i + 1], denoiser, x, cond, uc, None, ) return x def shift_latents(latents): # shift latents latents[:, :, :-1] = latents[:, :, 1:].clone() # add new noise to the last frame latents[:, :, -1] = torch.randn_like(latents[:, :, -1]) return latents class FIFOEDMSampler(FIFODiffusionSampler): """ The problem is that the original implementation doesn't take into consideration the condition. So we need to check if this can work with the condition. Don't have time to check this now. """ def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs): super().__init__(*args, **kwargs) self.s_churn = s_churn self.s_tmin = s_tmin self.s_tmax = s_tmax self.s_noise = s_noise def euler_step(self, x, d, dt): return x + dt * d def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc): return euler_step def concatenate_list_dict(self, dict1): for k, v in dict1.items(): if isinstance(v, list): dict1[k] = torch.cat(v, dim=0) else: dict1[k] = v return dict1 def prepare_latents(self, x, c, uc, sigmas, num_sigmas): latents_list = [] sigma_hat_list = [] sigma_next_list = [] c_list = defaultdict(list) uc_list = defaultdict(list) video = torch.load("/data/home/antoni/code/generative-models-dub/samples_z.pt") video = rearrange(video, "t c h w -> () c t h w") for k, v in c.items(): if not isinstance(v, torch.Tensor): c_list[k] = v uc_list[k] = uc[k] if self.lookahead: for i in range(self.num_frames // 2): gamma = ( min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 ) sigma = sigmas[i] sigma_hat = sigma * (gamma + 1.0) if gamma > 0: eps = torch.randn_like(video[:, :, [0]]) * self.s_noise latents = video[:, :, [0]] + eps * append_dims(sigma_hat**2 - sigma**2, video.ndim) ** 0.5 else: latents = video[:, :, [0]] for k, v in c.items(): if isinstance(v, torch.Tensor): c_list[k].append(v[[0]]) for k, v in uc.items(): if isinstance(v, torch.Tensor): uc_list[k].append(v[[0]]) latents_list.append(latents) sigma_hat_list.append(sigma_hat) sigma_next_list.append(sigmas[i + 1]) for i in range(num_sigmas - 1): gamma = ( min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 ) sigma = sigmas[i] sigma_hat = sigma * (gamma + 1.0) frame_idx = max(0, i - (num_sigmas - self.num_frames)) print(frame_idx) if gamma > 0: eps = torch.randn_like(video[:, :, [frame_idx]]) * self.s_noise latents = video[:, :, [frame_idx]] + eps * append_dims(sigma_hat**2 - sigma**2, video.ndim) ** 0.5 else: latents = video[:, :, [frame_idx]] for k, v in c.items(): if isinstance(v, torch.Tensor): c_list[k].append( v[[frame_idx]] if v.shape[0] == video.shape[2] else v[[frame_idx // self.num_frames]] ) for k, v in uc.items(): if isinstance(v, torch.Tensor): uc_list[k].append( v[[frame_idx]] if v.shape[0] == video.shape[2] else v[[frame_idx // self.num_frames]] ) latents_list.append(latents) sigma_hat_list.append(sigma_hat) sigma_next_list.append(sigmas[i + 1]) latents = torch.cat(latents_list, dim=2) sigma_hat = torch.stack(sigma_hat_list, dim=0) sigma_next = torch.stack(sigma_next_list, dim=0) c_list = self.concatenate_list_dict(c_list) uc_list = self.concatenate_list_dict(uc_list) return latents, sigma_hat, sigma_next, c_list, uc_list def sampler_step(self, sigma_hat, next_sigma, denoiser, x, cond, uc=None): denoised = self.denoise(x, denoiser, sigma_hat, cond, uc) if x.ndim == 5: x = rearrange(x, "b c t h w -> (b t) c h w") d = to_d(x, sigma_hat, denoised) dt = append_dims(next_sigma - sigma_hat, x.ndim) euler_step = self.euler_step(x, d, dt) x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) return x def merge_cond_dict(self, cond, total_n_frames): for k, v in cond.items(): if not isinstance(v, torch.Tensor): cond[k] = v else: if v.dim() == 5: cond[k] = rearrange(v, "b c t h w -> (b t) c h w") elif v.dim() == 3 and v.shape[0] != total_n_frames: cond[k] = rearrange(v, "b t c -> (b t) () c") else: cond[k] = v return cond def __call__(self, denoiser, x, cond, uc=None, num_steps=None, strength=1.0): x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc) x = rearrange(x, "b c h w -> () c b h w") cond = self.merge_cond_dict(cond, x.shape[2]) uc = self.merge_cond_dict(uc, x.shape[2]) total_n_frames = x.shape[2] latents, sigma_hat, sigma_next, cond, uc = self.prepare_latents(x, cond, uc, sigmas, num_sigmas) fifo_video_frames = [] for i in self.get_sigma_gen(num_sigmas, total_n_frames): for rank in reversed(range(2 * self.num_partitions if self.lookahead else self.num_partitions)): start_idx = rank * (self.num_frames // 2) if self.lookahead else rank * self.num_frames midpoint_idx = start_idx + self.num_frames // 2 end_idx = start_idx + self.num_frames chunk_x, sigma_hat_chunk, sigma_next_chunk, cond_chunk, uc_chunk = chunk_inputs( latents, cond, uc, sigma_hat, sigma_next, start_idx, end_idx, self.num_frames ) s_in = chunk_x.new_ones([chunk_x.shape[0]]) out = self.sampler_step( s_in * sigma_hat_chunk, s_in * sigma_next_chunk, denoiser, chunk_x, cond_chunk, uc=uc_chunk, ) if self.lookahead: latents[:, :, midpoint_idx:end_idx] = rearrange( out[-(self.num_frames // 2) :], "b c h w -> () c b h w" ) else: latents[:, :, start_idx:end_idx] = rearrange(out, "b c h w -> () c b h w") del out first_frame_idx = self.num_frames // 2 if self.lookahead else 0 latents = shift_latents(latents) fifo_video_frames.append(latents[:, :, [first_frame_idx]]) return rearrange(torch.cat(fifo_video_frames, dim=2), "() c b h w -> b c h w")[-total_n_frames:] class AncestralSampler(SingleStepDiffusionSampler): def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs): super().__init__(*args, **kwargs) self.eta = eta self.s_noise = s_noise self.noise_sampler = lambda x: torch.randn_like(x) def ancestral_euler_step(self, x, denoised, sigma, sigma_down): d = to_d(x, sigma, denoised) dt = append_dims(sigma_down - sigma, x.ndim) return self.euler_step(x, d, dt) def ancestral_step(self, x, sigma, next_sigma, sigma_up): x = torch.where( append_dims(next_sigma, x.ndim) > 0.0, x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim), x, ) return x def __call__(self, denoiser, x, cond, uc=None, num_steps=None): x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) for i in self.get_sigma_gen(num_sigmas): x = self.sampler_step( s_in * sigmas[i], s_in * sigmas[i + 1], denoiser, x, cond, uc, ) return x class LinearMultistepSampler(BaseDiffusionSampler): def __init__( self, order=4, *args, **kwargs, ): super().__init__(*args, **kwargs) self.order = order def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) ds = [] sigmas_cpu = sigmas.detach().cpu().numpy() for i in self.get_sigma_gen(num_sigmas): sigma = s_in * sigmas[i] denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs) denoised = self.guider(denoised, sigma) d = to_d(x, sigma, denoised) ds.append(d) if len(ds) > self.order: ds.pop(0) cur_order = min(i + 1, self.order) coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)] x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) return x class EulerEDMSampler(EDMSampler): def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc): return euler_step class EulerEDMSamplerPlusPlus(EDMSampleCFGplusplus): def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc): return euler_step class HeunEDMSampler(EDMSampler): def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc): if torch.sum(next_sigma) < 1e-14: # Save a network evaluation if all noise levels are 0 return euler_step else: denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc) d_new = to_d(euler_step, next_sigma, denoised) d_prime = (d + d_new) / 2.0 # apply correction if noise level is not 0 x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step) return x class EulerAncestralSampler(AncestralSampler): def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc): sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) denoised = self.denoise(x, denoiser, sigma, cond, uc) x = self.ancestral_euler_step(x, denoised, sigma, sigma_down) x = self.ancestral_step(x, sigma, next_sigma, sigma_up) return x class DPMPP2SAncestralSampler(AncestralSampler): def get_variables(self, sigma, sigma_down): t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)] h = t_next - t s = t + 0.5 * h return h, s, t, t_next def get_mult(self, h, s, t, t_next): mult1 = to_sigma(s) / to_sigma(t) mult2 = (-0.5 * h).expm1() mult3 = to_sigma(t_next) / to_sigma(t) mult4 = (-h).expm1() return mult1, mult2, mult3, mult4 def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs): sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) denoised = self.denoise(x, denoiser, sigma, cond, uc) x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down) if torch.sum(sigma_down) < 1e-14: # Save a network evaluation if all noise levels are 0 x = x_euler else: h, s, t, t_next = self.get_variables(sigma, sigma_down) mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)] x2 = mult[0] * x - mult[1] * denoised denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc) x_dpmpp2s = mult[2] * x - mult[3] * denoised2 # apply correction if noise level is not 0 x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler) x = self.ancestral_step(x, sigma, next_sigma, sigma_up) return x class DPMPP2MSampler(BaseDiffusionSampler): def get_variables(self, sigma, next_sigma, previous_sigma=None): t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)] h = t_next - t if previous_sigma is not None: h_last = t - to_neg_log_sigma(previous_sigma) r = h_last / h return h, r, t, t_next else: return h, None, t, t_next def get_mult(self, h, r, t, t_next, previous_sigma): mult1 = to_sigma(t_next) / to_sigma(t) mult2 = (-h).expm1() if previous_sigma is not None: mult3 = 1 + 1 / (2 * r) mult4 = 1 / (2 * r) return mult1, mult2, mult3, mult4 else: return mult1, mult2 def sampler_step( self, old_denoised, previous_sigma, sigma, next_sigma, denoiser, x, cond, uc=None, ): denoised = self.denoise(x, denoiser, sigma, cond, uc) h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)] x_standard = mult[0] * x - mult[1] * denoised if old_denoised is None or torch.sum(next_sigma) < 1e-14: # Save a network evaluation if all noise levels are 0 or on the first step return x_standard, denoised else: denoised_d = mult[2] * denoised - mult[3] * old_denoised x_advanced = mult[0] * x - mult[1] * denoised_d # apply correction if noise level is not 0 and not first step x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard) return x, denoised def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) old_denoised = None for i in self.get_sigma_gen(num_sigmas): x, old_denoised = self.sampler_step( old_denoised, None if i == 0 else s_in * sigmas[i - 1], s_in * sigmas[i], s_in * sigmas[i + 1], denoiser, x, cond, uc=uc, ) return x