Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
from typing import Union | |
def deterministic_ddpm_step(
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    scheduler,
):
    """Take one noise-free DDPM step from ``timestep`` to the previous timestep.

    The previous sample is assembled as::

        sqrt(alpha_prod_t_prev) * pred_original_sample + coef_D * model_output

    where ``pred_original_sample`` is the "predicted x_0" derived from
    ``model_output`` according to ``scheduler.config.prediction_type``. No
    variance/noise term is added.

    Args:
        model_output: Direct output of the diffusion model. When
            ``scheduler.variance_type`` is ``"learned"`` or ``"learned_range"``
            and dim 1 is twice as large as ``sample``'s dim 1, only the first
            half along dim 1 is used.
        timestep: Current timestep; used to index ``scheduler.alphas_cumprod``
            and passed to ``scheduler.previous_timestep``.
        sample: Current sample at ``timestep``.
        scheduler: Object exposing ``previous_timestep``, ``alphas_cumprod``,
            ``one``, ``variance_type``, ``_threshold_sample`` and a ``config``
            with ``prediction_type``, ``thresholding``, ``clip_sample`` and
            ``clip_sample_range``.

    Returns:
        A ``(pred_prev_sample, coef_D)`` tuple: the predicted previous sample
        and the scalar coefficient that multiplied ``model_output``.

    Raises:
        ValueError: If ``scheduler.config.prediction_type`` is not one of
            ``epsilon``, ``sample`` or ``v_prediction``.
    """
    t = timestep
    prev_t = scheduler.previous_timestep(t)

    # A model that also predicts the variance stacks it after the main
    # prediction along dim 1; this step is deterministic, so that half is
    # discarded.
    if model_output.shape[1] == sample.shape[1] * 2 and scheduler.variance_type in [
        "learned",
        "learned_range",
    ]:
        model_output, _ = torch.split(model_output, sample.shape[1], dim=1)

    # 1. compute alphas, betas
    alpha_prod_t = scheduler.alphas_cumprod[t]
    # A negative previous timestep means there is no earlier step; fall back
    # to ``scheduler.one`` as the cumulative alpha.
    alpha_prod_t_prev = (
        scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one
    )
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev

    # 2. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
    prediction_type = scheduler.config.prediction_type
    if prediction_type == "epsilon":
        pred_original_sample = (
            sample - beta_prod_t ** (0.5) * model_output
        ) / alpha_prod_t ** (0.5)
    elif prediction_type == "sample":
        pred_original_sample = model_output
    elif prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (
            beta_prod_t**0.5
        ) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for the DDPMScheduler."
        )

    # 3. Clip or threshold "predicted x_0"
    if scheduler.config.thresholding:
        pred_original_sample = scheduler._threshold_sample(pred_original_sample)
    elif scheduler.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -scheduler.config.clip_sample_range, scheduler.config.clip_sample_range
        )

    # 4. Combine "predicted x_0" with the model output.
    # coef_D = sqrt(current_alpha_t) * beta_prod_t_prev / sqrt(beta_prod_t)
    current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
    coef_D = current_sample_coeff * (beta_prod_t ** (0.5))
    # NOTE(review): the raw ``model_output`` is used here for every
    # prediction_type and regardless of clipping; presumably callers rely on
    # epsilon prediction -- confirm before using other configurations.
    pred_prev_sample = (alpha_prod_t_prev ** (0.5) * pred_original_sample) + (
        coef_D * model_output
    )
    return pred_prev_sample, coef_D