"""
Flow Euler Samplers for Generative Models
This file implements samplers for flow-matching generative models using the Euler integration method.
It contains three main sampler classes:
1. FlowEulerSampler: Base implementation of Euler sampling for flow-matching models
2. FlowEulerCfgSampler: Adds classifier-free guidance to the Euler sampler
3. FlowEulerGuidanceIntervalSampler: Enhances the sampler with both classifier-free guidance and guidance intervals
Flow-matching models define continuous paths from noise to data, and these samplers implement
ODE solvers (specifically the Euler method) to follow these paths and generate samples.
"""
from typing import *
import torch
import numpy as np
from tqdm import tqdm
from easydict import EasyDict as edict
from .base import Sampler
from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin
from .guidance_interval_mixin import GuidanceIntervalSamplerMixin


class FlowEulerSampler(Sampler):
"""
Generate samples from a flow-matching model using Euler sampling.
Args:
sigma_min: The minimum noise scale of the flow.
"""
def __init__(
self,
sigma_min: float,
):
# sigma_min controls the minimum noise level in the flow
self.sigma_min = sigma_min
def _eps_to_xstart(self, x_t, t, eps):
"""
Convert noise prediction (epsilon) to predicted clean data (x_0).
Args:
x_t: Current noisy tensor at timestep t
t: Current timestep
eps: Predicted noise
Returns:
Predicted clean data x_0
"""
assert x_t.shape == eps.shape
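# These conversions assume the linear flow interpolation
#   x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * eps,
# so x_0 is recovered by inverting it at the current t.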
return (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * eps) / (1 - t)
def _xstart_to_eps(self, x_t, t, x_0):
"""
Convert predicted clean data (x_0) to noise prediction (epsilon).
Args:
x_t: Current noisy tensor at timestep t
t: Current timestep
x_0: Predicted clean data
Returns:
Implied noise prediction epsilon
"""
assert x_t.shape == x_0.shape
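# Same interpolation as above, this time solved for eps given x_0.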
return (x_t - (1 - t) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t)
def _v_to_xstart_eps(self, x_t, t, v):
"""
Convert velocity prediction (v) to predicted clean data (x_0) and noise (epsilon).
Args:
x_t: Current noisy tensor at timestep t
t: Current timestep
v: Predicted velocity
Returns:
Tuple of (x_0, epsilon) derived from velocity
"""
assert x_t.shape == v.shape
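# Differentiating x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * eps with
# respect to t gives v = (1 - sigma_min) * eps - x_0; solving this pair of linear
# equations for (x_0, eps) yields the two expressions below.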
eps = (1 - t) * v + x_t
x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v
return x_0, eps
def _inference_model(self, model, x_t, t, cond=None, **kwargs):
"""
Run inference with the model.
Args:
model: The flow model
x_t: Current noisy tensor at timestep t
t: Current timestep (will be scaled by 1000)
cond: Conditional information
kwargs: Additional arguments for model
Returns:
Model's predicted velocity
"""
# Scale timestep by 1000 for model input
t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32)
# Broadcast single condition to match batch size if needed
# print(f"cond shape: {cond.shape}")
if cond is not None and cond.shape[0] == 1 and x_t.shape[0] > 1:
cond = cond.repeat(x_t.shape[0], *([1] * (len(cond.shape) - 1)))
# print(f"cond shape after repeat: {cond.shape}")
return model(x_t, t, cond, **kwargs)
def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs):
"""
Get model predictions and convert to various formats.
Args:
model: The flow model
x_t: Current noisy tensor at timestep t
t: Current timestep
cond: Conditional information
kwargs: Additional arguments for model
Returns:
Tuple of (x_0, epsilon, velocity) predictions
"""
pred_v = self._inference_model(model, x_t, t, cond, **kwargs)
pred_x_0, pred_eps = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v)
return pred_x_0, pred_eps, pred_v
@torch.no_grad()
def sample_once(
self,
model,
x_t,
t: float,
t_prev: float,
cond: Optional[Any] = None,
**kwargs
):
"""
Sample x_{t-1} from the model using a single Euler step.
Args:
model: The model to sample from.
x_t: The [N x C x ...] tensor of noisy inputs at time t.
t: The current timestep.
t_prev: The previous timestep.
cond: conditional information.
**kwargs: Additional arguments for model inference.
Returns:
a dict containing the following
- 'pred_x_prev': x_{t-1}.
- 'pred_x_0': a prediction of x_0.
"""
# Get model predictions
pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs)
# Euler step: x_{t-1} = x_t - (t - t_prev) * v_t
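# (t_prev < t, so repeated steps integrate the ODE dx/dt = v from t = 1, pure noise, toward t = 0, data.)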
pred_x_prev = x_t - (t - t_prev) * pred_v
return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0})
@torch.no_grad()
def sample(
self,
model,
noise,
cond: Optional[Any] = None,
steps: int = 50,
rescale_t: float = 1.0,
verbose: bool = True,
**kwargs
):
"""
Generate samples from the model using the Euler method.
Args:
model: The model to sample from.
noise: The initial noise tensor.
cond: conditional information.
steps: The number of steps to sample.
rescale_t: The rescale factor for t.
verbose: If True, show a progress bar.
**kwargs: Additional arguments for model inference.
Returns:
a dict containing the following
- 'samples': the model samples.
- 'pred_x_t': a list of prediction of x_t.
- 'pred_x_0': a list of prediction of x_0.
"""
sample = noise
# Create a linearly spaced timestep sequence from 1 to 0
t_seq = np.linspace(1, 0, steps + 1)
# Rescale timesteps as t -> rescale_t * t / (1 + (rescale_t - 1) * t); the endpoints 0 and 1
# stay fixed, and rescale_t > 1 concentrates steps near t = 1 (the high-noise end)
t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
# Create pairs of consecutive timesteps
t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps))
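# e.g. steps=4 with rescale_t=1.0 yields [(1.0, 0.75), (0.75, 0.5), (0.5, 0.25), (0.25, 0.0)]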
# Initialize return dictionary
ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []})
# print(f"shape of cond: {cond.shape}") # shape of cond: torch.Size([4, 1374, 1024])
# Perform Euler sampling steps
for t, t_prev in tqdm(t_pairs, desc="Sampling", disable=not verbose):
out = self.sample_once(model, sample, t, t_prev, cond, **kwargs)
sample = out.pred_x_prev
ret.pred_x_t.append(out.pred_x_prev)
ret.pred_x_0.append(out.pred_x_0)
ret.samples = sample
return ret


class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler):
"""
Generate samples from a flow-matching model using Euler sampling with classifier-free guidance.
This class adds classifier-free guidance to the Euler sampler, enabling conditional
generation with guidance strength control.
"""
@torch.no_grad()
def sample(
self,
model,
noise,
cond,
neg_cond,
steps: int = 50,
rescale_t: float = 1.0,
cfg_strength: float = 3.0,
verbose: bool = True,
**kwargs
):
"""
Generate samples from the model using the Euler method with classifier-free guidance.
Args:
model: The model to sample from.
noise: The initial noise tensor.
cond: conditional information.
neg_cond: negative conditional information.
steps: The number of steps to sample.
rescale_t: The rescale factor for t.
cfg_strength: The strength of classifier-free guidance.
verbose: If True, show a progress bar.
**kwargs: Additional arguments for model inference.
Returns:
a dict containing the following
- 'samples': the model samples.
- 'pred_x_t': a list of prediction of x_t.
- 'pred_x_0': a list of prediction of x_0.
"""
# Call the parent sample method with CFG parameters
return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, **kwargs)


class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSampler):
"""
Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and a guidance interval.
This class extends the Euler sampler with both classifier-free guidance and the ability
to specify timestep intervals where guidance is applied.
"""
@torch.no_grad()
def sample(
self,
model,
noise,
cond,
neg_cond,
steps: int = 50,
rescale_t: float = 1.0,
cfg_strength: float = 3.0,
cfg_interval: Tuple[float, float] = (0.0, 1.0),
verbose: bool = True,
**kwargs
):
"""
Generate samples from the model using the Euler method with classifier-free guidance applied only within a timestep interval.
Args:
model: The model to sample from.
noise: The initial noise tensor.
cond: conditional information.
neg_cond: negative conditional information.
steps: The number of steps to sample.
rescale_t: The rescale factor for t.
cfg_strength: The strength of classifier-free guidance.
cfg_interval: The interval for classifier-free guidance.
verbose: If True, show a progress bar.
**kwargs: Additional arguments for model inference.
Returns:
a dict containing the following
- 'samples': the model samples.
- 'pred_x_t': a list of prediction of x_t.
- 'pred_x_0': a list of prediction of x_0.
"""
# Call the parent sample method with CFG and interval parameters
return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs)
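

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch only, not part of the library API).
    # It drives FlowEulerSampler with a toy model that predicts a constant zero
    # velocity, just to exercise the integration loop and check output shapes;
    # a real model is a trained network with signature model(x_t, t, cond, **kwargs).
    # Run as a module (e.g. `python -m <package>.<this_module>`) so the relative
    # imports above resolve.
    class _ZeroVelocity(torch.nn.Module):
        def forward(self, x_t, t, cond=None, **kwargs):
            return torch.zeros_like(x_t)

    sampler = FlowEulerSampler(sigma_min=1e-5)  # sigma_min chosen arbitrarily for the test
    noise = torch.randn(2, 4, 16, 16)
    out = sampler.sample(_ZeroVelocity(), noise, cond=None, steps=10, verbose=False)
    assert out.samples.shape == noise.shape
    print("FlowEulerSampler smoke test passed:", tuple(out.samples.shape))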