Spaces:
Build error
Build error
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# SPDX-License-Identifier: Apache-2.0 | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" | |
A general framework for various sampling algorithms from a diffusion model.

Implementation based on:
* Refined Exponential Solver (RES) in https://arxiv.org/pdf/2308.02157
* also includes other implementations: DDIM, DEIS, DPM-Solver, EDM sampler.

Most sampling algorithms (Runge-Kutta, multi-step, etc.) can be implemented in this framework by
adding a new step function in get_runge_kutta_fn or get_multi_step_fn.
""" | |
import math | |
from typing import Any, Callable, List, Literal, Optional, Tuple, Union | |
import attrs | |
import torch | |
from cosmos_predict1.diffusion.functional.multi_step import get_multi_step_fn, is_multi_step_fn_supported | |
from cosmos_predict1.diffusion.functional.runge_kutta import get_runge_kutta_fn, is_runge_kutta_fn_supported | |
from cosmos_predict1.utils.config import make_freezable | |
# Solver names offered as a `typing.Literal`, so the alias can be used as a type annotation
# for a solver-name argument.
COMMON_SOLVER_OPTIONS = Literal["2ab", "2mid", "1euler"]
# NOTE(review): this class is constructed with keyword arguments elsewhere in this file, which
# requires a generated __init__; `attrs.define` provides it. `make_freezable` is imported by
# this file and otherwise unused in the visible code, so it is applied here on the assumption
# that the decorator lines were lost together with the indentation — confirm against upstream.
@make_freezable
@attrs.define(slots=False)
class SolverConfig:
    """Options selecting the update rule and the stochasticity of the solver."""

    # True -> use the multi-step update named by `multistep`;
    # False -> use the Runge-Kutta update named by `rk`.
    is_multi: bool = False
    # Name of the Runge-Kutta step function.
    rk: str = "2mid"
    # Name of the multi-step step function.
    multistep: str = "2ab"
    # The following parameters control stochasticity, see the EDM paper.
    # By default sampling is deterministic (no churn).
    s_churn: float = 0.0
    # Noise is only injected at steps whose sigma lies strictly between s_t_min and s_t_max.
    s_t_max: float = float("inf")
    s_t_min: float = 0.05
    # Multiplier applied to the injected noise.
    s_noise: float = 1.0
# NOTE(review): this class is constructed with keyword arguments elsewhere in this file, which
# requires a generated __init__; `attrs.define` provides it. `make_freezable` is imported by
# this file and otherwise unused in the visible code, so it is applied here on the assumption
# that the decorator lines were lost together with the indentation — confirm against upstream.
@make_freezable
@attrs.define(slots=False)
class SolverTimestampConfig:
    """Options describing the time-step (sigma) schedule handed to the solver."""

    # Evaluation budget; the sampler divides it by the solver order to get the number of steps.
    nfe: int = 50
    # Smallest and largest time value of the schedule.
    t_min: float = 0.002
    t_max: float = 80.0
    # Exponent shaping the spacing of the schedule between t_max and t_min.
    order: float = 7.0
    is_forward: bool = False  # whether to generate forward or backward timestamps
# NOTE(review): the fields below use `attrs.field(...)`, which only takes effect on an attrs
# class, and the class is constructed with keyword arguments elsewhere in this file.
# `make_freezable` is imported by this file and otherwise unused in the visible code, so it is
# applied here on the assumption that the decorator lines were lost — confirm against upstream.
@make_freezable
@attrs.define(slots=False)
class SamplerConfig:
    """Top-level sampler options: solver settings, time-step schedule and final clean-up step."""

    # Each instance gets its own default sub-config via the factory.
    solver: SolverConfig = attrs.field(factory=SolverConfig)
    timestamps: SolverTimestampConfig = attrs.field(factory=SolverTimestampConfig)
    sample_clean: bool = True  # whether to run one last denoising step to generate a clean sample
def get_rev_ts(
    t_min: float, t_max: float, num_steps: int, ts_order: Union[int, float], is_forward: bool = False
) -> torch.Tensor:
    """
    Generate a sequence of time steps running from `t_max` down to `t_min`.

    The steps are spaced uniformly in `t ** (1 / ts_order)` and then mapped back by raising
    to `ts_order`, giving `num_steps + 1` values in total.

    Args:
        t_min (float): The minimum time value.
        t_max (float): The maximum time value.
        num_steps (int): The number of intervals; the result has `num_steps + 1` entries.
        ts_order (Union[int, float]): The order of the time step progression.
        is_forward (bool, optional): If True, returns the sequence in forward (ascending)
            order. Defaults to False.

    Returns:
        torch.Tensor: A float64 tensor with the generated time steps in reverse or forward order.

    Raises:
        ValueError: If `t_min` is not less than `t_max`, or if `num_steps` is less than 1.
        TypeError: If `ts_order` is not an integer or float.
        ZeroDivisionError: If `ts_order` is zero.
    """
    if t_min >= t_max:
        raise ValueError("t_min must be less than t_max")
    if not isinstance(ts_order, (int, float)):
        raise TypeError("ts_order must be an integer or float")
    # num_steps is the divisor below; zero would silently yield a NaN schedule.
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")

    inv_order = 1 / ts_order
    root_max = t_max**inv_order
    root_min = t_min**inv_order

    # float64 keeps the schedule accurate regardless of the model's working dtype.
    step_indices = torch.arange(num_steps + 1, dtype=torch.float64)
    time_steps = (root_max + step_indices / num_steps * (root_min - root_max)) ** ts_order

    if is_forward:
        return time_steps.flip(dims=(0,))
    return time_steps
class Sampler(torch.nn.Module):
    """
    Module wrapper that runs a multi-step or Runge-Kutta solver over a sigma schedule.

    NOTE(review): decorator lines appear to have been lost from this copy of the file; if the
    upstream methods were wrapped (e.g. in a no-grad context), restore that — confirm.
    """

    def __init__(self, cfg: Optional[SamplerConfig] = None):
        """
        Args:
            cfg: Default sampler configuration, used by `_forward_impl` when it is called
                without an explicit one. A fresh `SamplerConfig()` is created when omitted.
        """
        super().__init__()
        if cfg is None:
            cfg = SamplerConfig()
        self.cfg = cfg

    def forward(
        self,
        x0_fn: Callable,
        x_sigma_max: torch.Tensor,
        num_steps: int = 35,
        sigma_min: float = 0.002,
        sigma_max: float = 80,
        rho: float = 7,
        S_churn: float = 0,
        S_min: float = 0,
        S_max: float = float("inf"),
        S_noise: float = 1,
        solver_option: str = "2ab",
    ) -> torch.Tensor:
        """
        Sample starting from `x_sigma_max` using the solver named by `solver_option`.

        Args:
            x0_fn: Denoiser called as `x0_fn(x, sigma)`; it receives both arguments cast to
                the dtype of `x_sigma_max`.
            x_sigma_max: Starting tensor; its first dimension is used as the batch size.
            num_steps: Evaluation budget (`nfe`) for the time-step schedule.
            sigma_min: Smallest sigma of the schedule.
            sigma_max: Largest sigma of the schedule.
            rho: Order of the time-step progression.
            S_churn: Churn amount controlling noise injection.
            S_min: Lower sigma bound for noise injection.
            S_max: Upper sigma bound for noise injection.
            S_noise: Multiplier on the injected noise.
            solver_option: Name of a multi-step or Runge-Kutta step function. If the name is
                supported by both families, the multi-step one is used.

        Returns:
            torch.Tensor: The sample, cast back to the dtype of `x_sigma_max`.
        """
        in_dtype = x_sigma_max.dtype

        def float64_x0_fn(x_B_StateShape: torch.Tensor, t_B: torch.Tensor) -> torch.Tensor:
            # The denoiser runs in the caller's dtype; its output is promoted to float64.
            return x0_fn(x_B_StateShape.to(in_dtype), t_B.to(in_dtype)).to(torch.float64)

        is_multistep = is_multi_step_fn_supported(solver_option)
        is_rk = is_runge_kutta_fn_supported(solver_option)
        assert is_multistep or is_rk, f"Only support multistep or Runge-Kutta method, got {solver_option}"

        # Both name fields receive solver_option; `is_multi` decides which one is consulted.
        solver_cfg = SolverConfig(
            s_churn=S_churn,
            s_t_max=S_max,
            s_t_min=S_min,
            s_noise=S_noise,
            is_multi=is_multistep,
            rk=solver_option,
            multistep=solver_option,
        )
        timestamps_cfg = SolverTimestampConfig(nfe=num_steps, t_min=sigma_min, t_max=sigma_max, order=rho)
        sampler_cfg = SamplerConfig(solver=solver_cfg, timestamps=timestamps_cfg, sample_clean=True)

        return self._forward_impl(float64_x0_fn, x_sigma_max, sampler_cfg).to(in_dtype)

    def _forward_impl(
        self,
        denoiser_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        noisy_input_B_StateShape: torch.Tensor,
        sampler_cfg: Optional[SamplerConfig] = None,
        callback_fns: Optional[List[Callable]] = None,
    ) -> torch.Tensor:
        """
        Internal implementation of the forward pass.

        Args:
            denoiser_fn: Function to denoise the input.
            noisy_input_B_StateShape: Input tensor with noise.
            sampler_cfg: Configuration for the sampler; falls back to `self.cfg` when None.
            callback_fns: List of callback functions to be called during sampling.

        Returns:
            torch.Tensor: Denoised output tensor.

        Raises:
            ValueError: If the evaluation budget `nfe` is smaller than the solver order, so
                that no solver step could be taken.
        """
        sampler_cfg = self.cfg if sampler_cfg is None else sampler_cfg

        # Multi-step solvers are counted as one evaluation per step; for Runge-Kutta solvers
        # the leading digit of the name (e.g. "2mid") is taken as the evaluations per step.
        solver_order = 1 if sampler_cfg.solver.is_multi else int(sampler_cfg.solver.rk[0])
        num_timestamps = sampler_cfg.timestamps.nfe // solver_order
        if num_timestamps < 1:
            # A zero-length schedule would produce NaN sigmas and a denoiser call at NaN.
            raise ValueError(
                f"nfe ({sampler_cfg.timestamps.nfe}) must be at least the solver order ({solver_order})"
            )

        sigmas_L = get_rev_ts(
            sampler_cfg.timestamps.t_min, sampler_cfg.timestamps.t_max, num_timestamps, sampler_cfg.timestamps.order
        ).to(noisy_input_B_StateShape.device)

        denoised_output = differential_equation_solver(
            denoiser_fn, sigmas_L, sampler_cfg.solver, callback_fns=callback_fns
        )(noisy_input_B_StateShape)

        if sampler_cfg.sample_clean:
            # Override denoised_output with fully denoised version
            ones = torch.ones(denoised_output.size(0), device=denoised_output.device, dtype=denoised_output.dtype)
            denoised_output = denoiser_fn(denoised_output, sigmas_L[-1] * ones)

        return denoised_output
def fori_loop(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_val: Any) -> Any:
    """
    Implements a for loop with a function, threading a value through every iteration.

    Args:
        lower: Lower bound of the loop (inclusive).
        upper: Upper bound of the loop (exclusive).
        body_fun: Function called as `body_fun(i, val)`; its return value becomes the
            `val` of the next iteration.
        init_val: Initial value for the loop.

    Returns:
        The final result after all iterations; `init_val` itself when `upper <= lower`.
    """
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val
def differential_equation_solver(
    x0_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    sigmas_L: torch.Tensor,
    solver_cfg: SolverConfig,
    callback_fns: Optional[List[Callable]] = None,
) -> Callable[[torch.Tensor], torch.Tensor]:
    """
    Creates a differential equation solver function.

    Args:
        x0_fn: Function to compute x0 prediction.
        sigmas_L: Tensor of sigma values with shape [L,]; consecutive pairs define the steps.
        solver_cfg: Configuration for the solver.
        callback_fns: Optional list of callback functions, each invoked after every step.

    Returns:
        A function that solves the differential equation.
    """
    num_step = len(sigmas_L) - 1

    if solver_cfg.is_multi:
        update_step_fn = get_multi_step_fn(solver_cfg.multistep)
    else:
        update_step_fn = get_runge_kutta_fn(solver_cfg.rk)

    # Per-step relative increase of sigma when churn is enabled, capped at sqrt(1.2) - 1.
    eta = min(solver_cfg.s_churn / (num_step + 1), math.sqrt(1.2) - 1)

    def sample_fn(input_xT_B_StateShape: torch.Tensor) -> torch.Tensor:
        """
        Samples from the differential equation.

        Args:
            input_xT_B_StateShape: Input tensor with shape [B, StateShape].

        Returns:
            Output tensor with shape [B, StateShape].
        """
        # Used to broadcast a scalar sigma to one value per batch element.
        ones_B = torch.ones(input_xT_B_StateShape.size(0), device=input_xT_B_StateShape.device, dtype=torch.float64)

        # NOTE: the callbacks below receive `locals()` of step_fn as keyword arguments, so the
        # local names in step_fn (and the enclosing names it references) are part of the
        # callback contract — do not rename them.
        def step_fn(
            i_th: int, state: Tuple[torch.Tensor, Optional[List[torch.Tensor]]]
        ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
            input_x_B_StateShape, x0_preds = state
            sigma_cur_0, sigma_next_0 = sigmas_L[i_th], sigmas_L[i_th + 1]

            # algorithm 2: line 4-6
            # (stochastic churn: raise sigma by a factor 1 + eta and add the matching noise;
            # presumably Algorithm 2 of the EDM paper — confirm)
            if solver_cfg.s_t_min < sigma_cur_0 < solver_cfg.s_t_max:
                hat_sigma_cur_0 = sigma_cur_0 + eta * sigma_cur_0
                input_x_B_StateShape = input_x_B_StateShape + (
                    hat_sigma_cur_0**2 - sigma_cur_0**2
                ).sqrt() * solver_cfg.s_noise * torch.randn_like(input_x_B_StateShape)
                sigma_cur_0 = hat_sigma_cur_0

            if solver_cfg.is_multi:
                # Multi-step: evaluate the denoiser once here and pass along the history.
                x0_pred_B_StateShape = x0_fn(input_x_B_StateShape, sigma_cur_0 * ones_B)
                output_x_B_StateShape, x0_preds = update_step_fn(
                    input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_pred_B_StateShape, x0_preds
                )
            else:
                # Runge-Kutta: the step function calls the denoiser itself.
                output_x_B_StateShape, x0_preds = update_step_fn(
                    input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_fn
                )

            if callback_fns:
                for callback_fn in callback_fns:
                    callback_fn(**locals())

            return output_x_B_StateShape, x0_preds

        # Initial state is a tuple, matching step_fn's annotated input and its return value.
        x_at_eps, _ = fori_loop(0, num_step, step_fn, (input_xT_B_StateShape, None))
        return x_at_eps

    return sample_fn