Commit
·
183c3a1
1
Parent(s):
531ea40
add a stabilizing trick for steps < 15
Browse files
Former-commit-id: bf3b8783543bdbfc31721479091e35696baadd13
ldm/models/diffusion/dpm_solver/dpm_solver.py
CHANGED
|
@@ -394,8 +394,8 @@ class DPM_Solver:
|
|
| 394 |
if self.thresholding:
|
| 395 |
p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
|
| 396 |
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
| 397 |
-
s = expand_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), dims)
|
| 398 |
-
x0 = torch.clamp(x0, -s, s) /
|
| 399 |
return x0
|
| 400 |
|
| 401 |
def model_fn(self, x, t):
|
|
@@ -436,7 +436,7 @@ class DPM_Solver:
|
|
| 436 |
else:
|
| 437 |
raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
|
| 438 |
|
| 439 |
-
def
|
| 440 |
"""
|
| 441 |
Get the order of each step for sampling by the singlestep DPM-Solver.
|
| 442 |
|
|
@@ -458,6 +458,13 @@ class DPM_Solver:
|
|
| 458 |
Args:
|
| 459 |
order: A `int`. The max order for the solver (2 or 3).
|
| 460 |
steps: A `int`. The total number of function evaluations (NFE).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
Returns:
|
| 462 |
orders: A list of the solver order of each step.
|
| 463 |
"""
|
|
@@ -469,20 +476,26 @@ class DPM_Solver:
|
|
| 469 |
orders = [3,] * (K - 1) + [1]
|
| 470 |
else:
|
| 471 |
orders = [3,] * (K - 1) + [2]
|
| 472 |
-
return orders
|
| 473 |
elif order == 2:
|
| 474 |
-
K = steps // 2
|
| 475 |
if steps % 2 == 0:
|
|
|
|
| 476 |
orders = [2,] * K
|
| 477 |
else:
|
| 478 |
-
|
| 479 |
-
|
| 480 |
elif order == 1:
|
| 481 |
-
|
|
|
|
| 482 |
else:
|
| 483 |
raise ValueError("'order' must be '1' or '2' or '3'.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 484 |
|
| 485 |
-
def
|
| 486 |
"""
|
| 487 |
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
|
| 488 |
"""
|
|
@@ -950,8 +963,8 @@ class DPM_Solver:
|
|
| 950 |
return x
|
| 951 |
|
| 952 |
def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
|
| 953 |
-
method='singlestep',
|
| 954 |
-
rtol=0.05,
|
| 955 |
):
|
| 956 |
"""
|
| 957 |
Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
|
|
@@ -1035,8 +1048,19 @@ class DPM_Solver:
|
|
| 1035 |
order: A `int`. The order of DPM-Solver.
|
| 1036 |
skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
|
| 1037 |
method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
|
| 1038 |
-
|
| 1039 |
-
If `
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1040 |
solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
|
| 1041 |
atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
|
| 1042 |
rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
|
|
@@ -1067,7 +1091,11 @@ class DPM_Solver:
|
|
| 1067 |
# Compute the remaining values by `order`-th order multistep DPM-Solver.
|
| 1068 |
for step in range(order, steps + 1):
|
| 1069 |
vec_t = timesteps[step].expand(x.shape[0])
|
| 1070 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1071 |
for i in range(order - 1):
|
| 1072 |
t_prev_list[i] = t_prev_list[i + 1]
|
| 1073 |
model_prev_list[i] = model_prev_list[i + 1]
|
|
@@ -1077,23 +1105,22 @@ class DPM_Solver:
|
|
| 1077 |
model_prev_list[-1] = self.model_fn(x, vec_t)
|
| 1078 |
elif method in ['singlestep', 'singlestep_fixed']:
|
| 1079 |
if method == 'singlestep':
|
| 1080 |
-
orders = self.
|
| 1081 |
-
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
| 1082 |
elif method == 'singlestep_fixed':
|
| 1083 |
K = steps // order
|
| 1084 |
orders = [order,] * K
|
| 1085 |
-
|
| 1086 |
-
|
| 1087 |
-
|
| 1088 |
-
|
| 1089 |
-
|
| 1090 |
-
|
| 1091 |
-
|
| 1092 |
-
|
| 1093 |
-
|
| 1094 |
-
|
| 1095 |
-
if
|
| 1096 |
-
x = self.
|
| 1097 |
return x
|
| 1098 |
|
| 1099 |
|
|
|
|
| 394 |
if self.thresholding:
|
| 395 |
p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
|
| 396 |
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
| 397 |
+
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
| 398 |
+
x0 = torch.clamp(x0, -s, s) / s
|
| 399 |
return x0
|
| 400 |
|
| 401 |
def model_fn(self, x, t):
|
|
|
|
| 436 |
else:
|
| 437 |
raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
|
| 438 |
|
| 439 |
+
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
|
| 440 |
"""
|
| 441 |
Get the order of each step for sampling by the singlestep DPM-Solver.
|
| 442 |
|
|
|
|
| 458 |
Args:
|
| 459 |
order: A `int`. The max order for the solver (2 or 3).
|
| 460 |
steps: A `int`. The total number of function evaluations (NFE).
|
| 461 |
+
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
|
| 462 |
+
- 'logSNR': uniform logSNR for the time steps.
|
| 463 |
+
- 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
|
| 464 |
+
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
|
| 465 |
+
t_T: A `float`. The starting time of the sampling (default is T).
|
| 466 |
+
t_0: A `float`. The ending time of the sampling (default is epsilon).
|
| 467 |
+
device: A torch device.
|
| 468 |
Returns:
|
| 469 |
orders: A list of the solver order of each step.
|
| 470 |
"""
|
|
|
|
| 476 |
orders = [3,] * (K - 1) + [1]
|
| 477 |
else:
|
| 478 |
orders = [3,] * (K - 1) + [2]
|
|
|
|
| 479 |
elif order == 2:
|
|
|
|
| 480 |
if steps % 2 == 0:
|
| 481 |
+
K = steps // 2
|
| 482 |
orders = [2,] * K
|
| 483 |
else:
|
| 484 |
+
K = steps // 2 + 1
|
| 485 |
+
orders = [2,] * (K - 1) + [1]
|
| 486 |
elif order == 1:
|
| 487 |
+
K = 1
|
| 488 |
+
orders = [1,] * steps
|
| 489 |
else:
|
| 490 |
raise ValueError("'order' must be '1' or '2' or '3'.")
|
| 491 |
+
if skip_type == 'logSNR':
|
| 492 |
+
# To reproduce the results in DPM-Solver paper
|
| 493 |
+
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
|
| 494 |
+
else:
|
| 495 |
+
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders)).to(device)]
|
| 496 |
+
return timesteps_outer, orders
|
| 497 |
|
| 498 |
+
def denoise_to_zero_fn(self, x, s):
|
| 499 |
"""
|
| 500 |
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
|
| 501 |
"""
|
|
|
|
| 963 |
return x
|
| 964 |
|
| 965 |
def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
|
| 966 |
+
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
|
| 967 |
+
atol=0.0078, rtol=0.05,
|
| 968 |
):
|
| 969 |
"""
|
| 970 |
Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
|
|
|
|
| 1048 |
order: A `int`. The order of DPM-Solver.
|
| 1049 |
skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
|
| 1050 |
method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
|
| 1051 |
+
denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
|
| 1052 |
+
Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
|
| 1053 |
+
|
| 1054 |
+
This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
|
| 1055 |
+
score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
|
| 1056 |
+
for diffusion models sampling by diffusion SDEs for low-resolutional images
|
| 1057 |
+
(such as CIFAR-10). However, we observed that such trick does not matter for
|
| 1058 |
+
high-resolutional images. As it needs an additional NFE, we do not recommend
|
| 1059 |
+
it for high-resolutional images.
|
| 1060 |
+
lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
|
| 1061 |
+
Only valid for `method=multistep` and `steps < 15`. We empirically find that
|
| 1062 |
+
this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
|
| 1063 |
+
(especially for steps <= 10). So we recommend to set it to be `True`.
|
| 1064 |
solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
|
| 1065 |
atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
|
| 1066 |
rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
|
|
|
|
| 1091 |
# Compute the remaining values by `order`-th order multistep DPM-Solver.
|
| 1092 |
for step in range(order, steps + 1):
|
| 1093 |
vec_t = timesteps[step].expand(x.shape[0])
|
| 1094 |
+
if lower_order_final and steps < 15:
|
| 1095 |
+
step_order = min(order, steps + 1 - step)
|
| 1096 |
+
else:
|
| 1097 |
+
step_order = order
|
| 1098 |
+
x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, solver_type=solver_type)
|
| 1099 |
for i in range(order - 1):
|
| 1100 |
t_prev_list[i] = t_prev_list[i + 1]
|
| 1101 |
model_prev_list[i] = model_prev_list[i + 1]
|
|
|
|
| 1105 |
model_prev_list[-1] = self.model_fn(x, vec_t)
|
| 1106 |
elif method in ['singlestep', 'singlestep_fixed']:
|
| 1107 |
if method == 'singlestep':
|
| 1108 |
+
timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device)
|
|
|
|
| 1109 |
elif method == 'singlestep_fixed':
|
| 1110 |
K = steps // order
|
| 1111 |
orders = [order,] * K
|
| 1112 |
+
timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
|
| 1113 |
+
for i, order in enumerate(orders):
|
| 1114 |
+
t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
|
| 1115 |
+
timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), N=order, device=device)
|
| 1116 |
+
lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
|
| 1117 |
+
vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
|
| 1118 |
+
h = lambda_inner[-1] - lambda_inner[0]
|
| 1119 |
+
r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
|
| 1120 |
+
r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
|
| 1121 |
+
x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
|
| 1122 |
+
if denoise_to_zero:
|
| 1123 |
+
x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
|
| 1124 |
return x
|
| 1125 |
|
| 1126 |
|
ldm/models/diffusion/dpm_solver/sampler.py
CHANGED
|
@@ -77,6 +77,6 @@ class DPMSolverSampler(object):
|
|
| 77 |
)
|
| 78 |
|
| 79 |
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
|
| 80 |
-
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2)
|
| 81 |
|
| 82 |
return x.to(device), None
|
|
|
|
| 77 |
)
|
| 78 |
|
| 79 |
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
|
| 80 |
+
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
|
| 81 |
|
| 82 |
return x.to(device), None
|