omnipart's picture
init
491eded
"""
Flow Euler Samplers for Generative Models
This file implements samplers for flow-matching generative models using the Euler integration method.
It contains three main sampler classes:
1. FlowEulerSampler: Base implementation of Euler sampling for flow-matching models
2. FlowEulerCfgSampler: Adds classifier-free guidance to the Euler sampler
3. FlowEulerGuidanceIntervalSampler: Enhances the sampler with both classifier-free guidance and guidance intervals
Flow-matching models define continuous paths between noise and data. These samplers implement
an ODE solver (the explicit Euler method) that follows these paths from t = 1 (noise) to t = 0 (data) to generate samples.
"""
from typing import *
import torch
import numpy as np
from tqdm import tqdm
from easydict import EasyDict as edict
from .base import Sampler
from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin
from .guidance_interval_mixin import GuidanceIntervalSamplerMixin
class FlowEulerSampler(Sampler):
    """
    Generate samples from a flow-matching model using Euler sampling.

    The conversion helpers below assume the interpolation path

        x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * eps,

    so t = 0 is (nearly) clean data and t = 1 is pure noise. Sampling
    integrates the model's predicted velocity from t = 1 down to t = 0.

    Args:
        sigma_min: The minimum scale of noise in the flow, i.e. the noise
            coefficient of x_t at t = 0.
    """

    def __init__(
        self,
        sigma_min: float,
    ):
        # Noise coefficient at t = 0; appears in every x_0 / eps / v conversion.
        self.sigma_min = sigma_min

    def _eps_to_xstart(self, x_t, t, eps):
        """
        Convert a noise prediction (epsilon) to a predicted clean sample (x_0).

        Args:
            x_t: Current noisy tensor at timestep t.
            t: Current timestep. Must not be 1 (division by ``1 - t``).
            eps: Predicted noise, same shape as ``x_t``.

        Returns:
            Predicted clean data x_0.
        """
        assert x_t.shape == eps.shape
        return (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * eps) / (1 - t)

    def _xstart_to_eps(self, x_t, t, x_0):
        """
        Convert a predicted clean sample (x_0) to the implied noise (epsilon).

        Args:
            x_t: Current noisy tensor at timestep t.
            t: Current timestep.
            x_0: Predicted clean data, same shape as ``x_t``.

        Returns:
            Implied noise prediction epsilon.
        """
        assert x_t.shape == x_0.shape
        return (x_t - (1 - t) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t)

    def _v_to_xstart_eps(self, x_t, t, v):
        """
        Convert a velocity prediction (v) to predicted clean data and noise.

        With v = d x_t / dt = (1 - sigma_min) * eps - x_0, solving the path
        equation for eps and x_0 gives the two expressions below.

        Args:
            x_t: Current noisy tensor at timestep t.
            t: Current timestep.
            v: Predicted velocity, same shape as ``x_t``.

        Returns:
            Tuple ``(x_0, eps)`` derived from the velocity.
        """
        assert x_t.shape == v.shape
        eps = (1 - t) * v + x_t
        x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v
        return x_0, eps

    def _inference_model(self, model, x_t, t, cond=None, **kwargs):
        """
        Run a single forward pass of the model.

        Args:
            model: The flow model, called as ``model(x_t, t, cond, **kwargs)``.
            x_t: Current noisy tensor at timestep t; its leading dimension is
                used as the batch size.
            t: Current scalar timestep in [0, 1]. It is scaled by 1000 and
                replicated once per batch element before being passed on.
            cond: Conditional information. If it has a non-empty ``shape``
                whose leading dimension is 1 while the batch is larger, it is
                repeated along that dimension to match the batch. Conditions
                without a ``shape`` (e.g. ``None``) are passed through unchanged.
            kwargs: Additional arguments for the model.

        Returns:
            The model output, which callers interpret as the predicted velocity.
        """
        batch_size = x_t.shape[0]
        # One timestep per batch element, on the model's 0-1000 scale.
        t = torch.full(
            (batch_size,), float(1000 * t), device=x_t.device, dtype=torch.float32
        )
        # Broadcast a single condition across the batch. Guard on `shape`
        # (not just `is not None`) so conditions that are not array-like
        # are forwarded as-is instead of raising AttributeError.
        cond_shape = getattr(cond, "shape", None)
        if (
            cond_shape is not None
            and len(cond_shape) > 0
            and cond_shape[0] == 1
            and batch_size > 1
        ):
            cond = cond.repeat(batch_size, *([1] * (len(cond_shape) - 1)))
        return model(x_t, t, cond, **kwargs)

    def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs):
        """
        Query the model and express its prediction in all three parameterizations.

        Args:
            model: The flow model.
            x_t: Current noisy tensor at timestep t.
            t: Current timestep.
            cond: Conditional information.
            kwargs: Additional arguments for the model.

        Returns:
            Tuple ``(x_0, eps, v)`` of predictions.
        """
        pred_v = self._inference_model(model, x_t, t, cond, **kwargs)
        pred_x_0, pred_eps = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v)
        return pred_x_0, pred_eps, pred_v

    @torch.no_grad()
    def sample_once(
        self,
        model,
        x_t,
        t: float,
        t_prev: float,
        cond: Optional[Any] = None,
        **kwargs
    ):
        """
        Take one Euler step from timestep ``t`` to ``t_prev``.

        Args:
            model: The model to sample from.
            x_t: The [N x C x ...] tensor of noisy inputs at time t.
            t: The current timestep.
            t_prev: The timestep to step to (t_prev < t when sampling).
            cond: Conditional information.
            **kwargs: Additional arguments for model inference.

        Returns:
            A dict containing:
            - 'pred_x_prev': the sample at ``t_prev``.
            - 'pred_x_0': the model's prediction of x_0 at ``t``.
        """
        pred_x_0, _pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs)
        # Explicit Euler step of dx/dt = v, integrating backwards in t.
        pred_x_prev = x_t - (t - t_prev) * pred_v
        return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0})

    @torch.no_grad()
    def sample(
        self,
        model,
        noise,
        cond: Optional[Any] = None,
        steps: int = 50,
        rescale_t: float = 1.0,
        verbose: bool = True,
        **kwargs
    ):
        """
        Generate samples by Euler-integrating the flow from t = 1 to t = 0.

        Args:
            model: The model to sample from.
            noise: The initial noise tensor (the state at t = 1).
            cond: Conditional information.
            steps: The number of Euler steps.
            rescale_t: Timestep warp factor s, applied as
                t' = s * t / (1 + (s - 1) * t). It keeps t = 0 and t = 1
                fixed; s > 1 moves grid points toward t = 1, i.e. spends more
                (smaller) steps at the noisy end. s = 1 is a uniform grid.
            verbose: If True, show a progress bar.
            **kwargs: Additional arguments forwarded to each model call.

        Returns:
            A dict containing:
            - 'samples': the final samples (``noise`` itself if ``steps == 0``).
            - 'pred_x_t': list of the intermediate states after each step
              (the last entry equals 'samples' when ``steps > 0``).
            - 'pred_x_0': list of the x_0 predictions made at each step.
        """
        sample = noise
        # Uniform grid from 1 down to 0, then warped by rescale_t.
        t_seq = np.linspace(1, 0, steps + 1)
        t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
        # Consecutive (t, t_prev) pairs, one per Euler step. Materialized as a
        # list so tqdm knows the total length.
        t_pairs = list(zip(t_seq[:-1], t_seq[1:]))
        ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []})
        for t, t_prev in tqdm(t_pairs, desc="Sampling", disable=not verbose):
            out = self.sample_once(model, sample, t, t_prev, cond, **kwargs)
            sample = out.pred_x_prev
            ret.pred_x_t.append(out.pred_x_prev)
            ret.pred_x_0.append(out.pred_x_0)
        ret.samples = sample
        return ret
class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler):
    """
    Euler flow-matching sampler with classifier-free guidance (CFG).

    Sampling is delegated to the parent ``sample``. The only addition is that
    ``neg_cond`` and ``cfg_strength`` are forwarded as extra keyword arguments.
    These travel down to the per-step model call, where they are presumably
    consumed by ClassifierFreeGuidanceSamplerMixin (defined elsewhere).
    """

    @torch.no_grad()
    def sample(
        self,
        model,
        noise,
        cond,
        neg_cond,
        steps: int = 50,
        rescale_t: float = 1.0,
        cfg_strength: float = 3.0,
        verbose: bool = True,
        **kwargs
    ):
        """
        Generate guided samples using the Euler method.

        Args:
            model: The model to sample from.
            noise: The initial noise tensor.
            cond: Positive conditional information.
            neg_cond: Negative (unconditional) conditional information.
            steps: The number of Euler steps.
            rescale_t: The timestep warp factor passed to the base sampler.
            cfg_strength: The strength of classifier-free guidance.
            verbose: If True, show a progress bar.
            **kwargs: Additional arguments forwarded to each model call.

        Returns:
            A dict containing:
            - 'samples': the final samples.
            - 'pred_x_t': list of intermediate states.
            - 'pred_x_0': list of x_0 predictions.
        """
        # Guidance settings ride along as keyword arguments to every model call.
        guidance = {"neg_cond": neg_cond, "cfg_strength": cfg_strength}
        return super().sample(
            model, noise, cond, steps, rescale_t, verbose, **guidance, **kwargs
        )
class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSampler):
    """
    Euler flow-matching sampler with classifier-free guidance limited to a
    timestep interval.

    Sampling is delegated to the parent ``sample``. ``neg_cond``,
    ``cfg_strength`` and ``cfg_interval`` are forwarded as extra keyword
    arguments to every model call, where they are presumably consumed by
    GuidanceIntervalSamplerMixin (defined elsewhere).
    """

    @torch.no_grad()
    def sample(
        self,
        model,
        noise,
        cond,
        neg_cond,
        steps: int = 50,
        rescale_t: float = 1.0,
        cfg_strength: float = 3.0,
        cfg_interval: Tuple[float, float] = (0.0, 1.0),
        verbose: bool = True,
        **kwargs
    ):
        """
        Generate interval-guided samples using the Euler method.

        Args:
            model: The model to sample from.
            noise: The initial noise tensor.
            cond: Positive conditional information.
            neg_cond: Negative (unconditional) conditional information.
            steps: The number of Euler steps.
            rescale_t: The timestep warp factor passed to the base sampler.
            cfg_strength: The strength of classifier-free guidance.
            cfg_interval: ``(low, high)`` timestep range for guidance; forwarded
                unchanged to the guidance mixin.
            verbose: If True, show a progress bar.
            **kwargs: Additional arguments forwarded to each model call.

        Returns:
            A dict containing:
            - 'samples': the final samples.
            - 'pred_x_t': list of intermediate states.
            - 'pred_x_0': list of x_0 predictions.
        """
        # Bundle the guidance controls; they are forwarded to each model call.
        guidance = {
            "neg_cond": neg_cond,
            "cfg_strength": cfg_strength,
            "cfg_interval": cfg_interval,
        }
        return super().sample(
            model, noise, cond, steps, rescale_t, verbose, **guidance, **kwargs
        )