"""
Flow Euler Samplers for Generative Models

This file implements samplers for flow-matching generative models using the Euler integration method.

It contains three main sampler classes:

1. FlowEulerSampler: Base implementation of Euler sampling for flow-matching models
2. FlowEulerCfgSampler: Adds classifier-free guidance to the Euler sampler
3. FlowEulerGuidanceIntervalSampler: Extends the Euler sampler with both classifier-free guidance and guidance intervals

Flow-matching models define continuous paths from noise to data, and these samplers implement
ODE solvers (specifically the Euler method) to follow these paths and generate samples.
"""

from typing import *
import torch
import numpy as np
from tqdm import tqdm
from easydict import EasyDict as edict

from .base import Sampler
from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin
from .guidance_interval_mixin import GuidanceIntervalSamplerMixin


class FlowEulerSampler(Sampler):
    """
    Generate samples from a flow-matching model using Euler sampling.

    Args:
        sigma_min: The minimum scale of noise in the flow.
    """
    def __init__(
        self,
        sigma_min: float,
    ):
        self.sigma_min = sigma_min
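
    # The conversion helpers below follow from the linear interpolation used by this
    # sampler (consistent with _eps_to_xstart and _xstart_to_eps):
    #     x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * eps
    # with velocity
    #     v = d x_t / d t = (1 - sigma_min) * eps - x_0,
    # where t runs from 1 (pure noise) to 0 (clean data). Solving this pair of linear
    # equations for x_0 and eps gives the closed-form conversions implemented below.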
|
|
|
    def _eps_to_xstart(self, x_t, t, eps):
        """
        Convert noise prediction (epsilon) to predicted clean data (x_0).

        Args:
            x_t: Current noisy tensor at timestep t
            t: Current timestep
            eps: Predicted noise

        Returns:
            Predicted clean data x_0
        """
        assert x_t.shape == eps.shape
        return (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * eps) / (1 - t)

    def _xstart_to_eps(self, x_t, t, x_0):
        """
        Convert predicted clean data (x_0) to noise prediction (epsilon).

        Args:
            x_t: Current noisy tensor at timestep t
            t: Current timestep
            x_0: Predicted clean data

        Returns:
            Implied noise prediction epsilon
        """
        assert x_t.shape == x_0.shape
        return (x_t - (1 - t) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t)

    def _v_to_xstart_eps(self, x_t, t, v):
        """
        Convert velocity prediction (v) to predicted clean data (x_0) and noise (epsilon).

        Args:
            x_t: Current noisy tensor at timestep t
            t: Current timestep
            v: Predicted velocity

        Returns:
            Tuple of (x_0, epsilon) derived from the velocity
        """
        assert x_t.shape == v.shape
        eps = (1 - t) * v + x_t
        x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v
        return x_0, eps
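
    # Numerical sanity check for the conversions above (a sketch, not part of the API):
    #     s = FlowEulerSampler(sigma_min=1e-5)
    #     x_0, eps, t = torch.randn(2, 4), torch.randn(2, 4), 0.3
    #     x_t = (1 - t) * x_0 + (s.sigma_min + (1 - s.sigma_min) * t) * eps
    #     v = (1 - s.sigma_min) * eps - x_0
    #     rec_x_0, rec_eps = s._v_to_xstart_eps(x_t, t, v)
    #     assert torch.allclose(rec_x_0, x_0, atol=1e-4)
    #     assert torch.allclose(rec_eps, eps, atol=1e-4)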
|
|
|
    def _inference_model(self, model, x_t, t, cond=None, **kwargs):
        """
        Run inference with the model.

        Args:
            model: The flow model
            x_t: Current noisy tensor at timestep t
            t: Current timestep (scaled by 1000 before being passed to the model)
            cond: Conditional information
            kwargs: Additional arguments for the model

        Returns:
            The model's predicted velocity
        """
        t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32)
        # Broadcast a single conditioning entry across the batch if needed.
        if cond is not None and cond.shape[0] == 1 and x_t.shape[0] > 1:
            cond = cond.repeat(x_t.shape[0], *([1] * (len(cond.shape) - 1)))
        return model(x_t, t, cond, **kwargs)
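
    # Note: the guided samplers below pass `neg_cond`, `cfg_strength` (and `cfg_interval`)
    # down through **kwargs; ClassifierFreeGuidanceSamplerMixin and
    # GuidanceIntervalSamplerMixin are expected to hook into this inference path and
    # consume those arguments before the underlying model is called.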
|
|
|
    def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs):
        """
        Get model predictions and convert to various formats.

        Args:
            model: The flow model
            x_t: Current noisy tensor at timestep t
            t: Current timestep
            cond: Conditional information
            kwargs: Additional arguments for the model

        Returns:
            Tuple of (x_0, epsilon, velocity) predictions
        """
        pred_v = self._inference_model(model, x_t, t, cond, **kwargs)
        pred_x_0, pred_eps = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v)
        return pred_x_0, pred_eps, pred_v

    @torch.no_grad()
    def sample_once(
        self,
        model,
        x_t,
        t: float,
        t_prev: float,
        cond: Optional[Any] = None,
        **kwargs
    ):
        """
        Sample x_{t_prev} from the model using a single Euler step.

        Args:
            model: The model to sample from.
            x_t: The [N x C x ...] tensor of noisy inputs at time t.
            t: The current timestep.
            t_prev: The previous timestep.
            cond: conditional information.
            **kwargs: Additional arguments for model inference.

        Returns:
            a dict containing the following
            - 'pred_x_prev': x_{t_prev}.
            - 'pred_x_0': a prediction of x_0.
        """
        pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs)
        # Euler step from t to t_prev: x_{t_prev} ~ x_t + (t_prev - t) * v, with v = dx/dt.
        pred_x_prev = x_t - (t - t_prev) * pred_v
        return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0})

    @torch.no_grad()
    def sample(
        self,
        model,
        noise,
        cond: Optional[Any] = None,
        steps: int = 50,
        rescale_t: float = 1.0,
        verbose: bool = True,
        **kwargs
    ):
        """
        Generate samples from the model using the Euler method.

        Args:
            model: The model to sample from.
            noise: The initial noise tensor.
            cond: conditional information.
            steps: The number of steps to sample.
            rescale_t: The rescale factor for t.
            verbose: If True, show a progress bar.
            **kwargs: Additional arguments for model inference.

        Returns:
            a dict containing the following
            - 'samples': the model samples.
            - 'pred_x_t': a list of predictions of x_t.
            - 'pred_x_0': a list of predictions of x_0.
        """
        sample = noise
        t_seq = np.linspace(1, 0, steps + 1)
        t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
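        # Example: with steps=4 and rescale_t=3, the uniform schedule
        # [1.00, 0.75, 0.50, 0.25, 0.00] maps to [1.00, 0.90, 0.75, 0.50, 0.00],
        # i.e. larger rescale_t values concentrate steps near t = 1 (the noise end).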
|
|
|
        t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps))
        ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []})
        for t, t_prev in tqdm(t_pairs, desc="Sampling", disable=not verbose):
            out = self.sample_once(model, sample, t, t_prev, cond, **kwargs)
            sample = out.pred_x_prev
            ret.pred_x_t.append(out.pred_x_prev)
            ret.pred_x_0.append(out.pred_x_0)
        ret.samples = sample
        return ret
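

# Example usage of FlowEulerSampler (a minimal sketch; `model`, `noise`, and `cond`
# are assumed to be provided by the surrounding pipeline and are not defined here):
#
#     sampler = FlowEulerSampler(sigma_min=1e-5)
#     out = sampler.sample(model, noise, cond=cond, steps=50, rescale_t=1.0)
#     samples = out.samples        # final samples
#     trajectory = out.pred_x_t    # per-step predictions of x_t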
|
|
|
|
|
class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler):
    """
    Generate samples from a flow-matching model using Euler sampling with classifier-free guidance.

    This class adds classifier-free guidance to the Euler sampler, enabling conditional
    generation with control over the guidance strength.
    """
    @torch.no_grad()
    def sample(
        self,
        model,
        noise,
        cond,
        neg_cond,
        steps: int = 50,
        rescale_t: float = 1.0,
        cfg_strength: float = 3.0,
        verbose: bool = True,
        **kwargs
    ):
        """
        Generate samples from the model using the Euler method.

        Args:
            model: The model to sample from.
            noise: The initial noise tensor.
            cond: conditional information.
            neg_cond: negative conditional information.
            steps: The number of steps to sample.
            rescale_t: The rescale factor for t.
            cfg_strength: The strength of classifier-free guidance.
            verbose: If True, show a progress bar.
            **kwargs: Additional arguments for model inference.

        Returns:
            a dict containing the following
            - 'samples': the model samples.
            - 'pred_x_t': a list of predictions of x_t.
            - 'pred_x_0': a list of predictions of x_0.
        """
        return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, **kwargs)
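

# Example with classifier-free guidance (a sketch; `model`, `noise`, `cond`, and
# `neg_cond` are assumed to be produced elsewhere):
#
#     sampler = FlowEulerCfgSampler(sigma_min=1e-5)
#     out = sampler.sample(model, noise, cond, neg_cond, steps=50, cfg_strength=3.0)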
|
|
|
|
|
class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSampler):
    """
    Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and a guidance interval.

    This class extends the Euler sampler with both classifier-free guidance and the ability
    to specify the timestep interval in which guidance is applied.
    """
    @torch.no_grad()
    def sample(
        self,
        model,
        noise,
        cond,
        neg_cond,
        steps: int = 50,
        rescale_t: float = 1.0,
        cfg_strength: float = 3.0,
        cfg_interval: Tuple[float, float] = (0.0, 1.0),
        verbose: bool = True,
        **kwargs
    ):
        """
        Generate samples from the model using the Euler method.

        Args:
            model: The model to sample from.
            noise: The initial noise tensor.
            cond: conditional information.
            neg_cond: negative conditional information.
            steps: The number of steps to sample.
            rescale_t: The rescale factor for t.
            cfg_strength: The strength of classifier-free guidance.
            cfg_interval: The timestep interval in which classifier-free guidance is applied.
            verbose: If True, show a progress bar.
            **kwargs: Additional arguments for model inference.

        Returns:
            a dict containing the following
            - 'samples': the model samples.
            - 'pred_x_t': a list of predictions of x_t.
            - 'pred_x_0': a list of predictions of x_0.
        """
        return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs)
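

# Example with guidance restricted to part of the trajectory (a sketch; the exact
# interpretation of `cfg_interval` is defined by GuidanceIntervalSamplerMixin, and
# (0.5, 1.0) is only an illustrative value):
#
#     sampler = FlowEulerGuidanceIntervalSampler(sigma_min=1e-5)
#     out = sampler.sample(model, noise, cond, neg_cond, steps=50,
#                          cfg_strength=3.0, cfg_interval=(0.5, 1.0))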
|
|