# cora/src/ddpm_step.py
import torch
from typing import Union


def deterministic_ddpm_step(
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
scheduler,
):
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
"""
t = timestep
prev_t = scheduler.previous_timestep(t)
if model_output.shape[1] == sample.shape[1] * 2 and scheduler.variance_type in [
"learned",
"learned_range",
]:
model_output, predicted_variance = torch.split(
model_output, sample.shape[1], dim=1
)
else:
predicted_variance = None
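    # Note: `predicted_variance` is not used below; this deterministic variant only
    # propagates the posterior mean and does not sample any noise.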
# 1. compute alphas, betas
alpha_prod_t = scheduler.alphas_cumprod[t]
alpha_prod_t_prev = (
scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one
)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
current_alpha_t = alpha_prod_t / alpha_prod_t_prev
current_beta_t = 1 - current_alpha_t
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if scheduler.config.prediction_type == "epsilon":
pred_original_sample = (
sample - beta_prod_t ** (0.5) * model_output
) / alpha_prod_t ** (0.5)
elif scheduler.config.prediction_type == "sample":
pred_original_sample = model_output
elif scheduler.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (
beta_prod_t**0.5
) * model_output
else:
raise ValueError(
f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
" `v_prediction` for the DDPMScheduler."
)
# 3. Clip or threshold "predicted x_0"
if scheduler.config.thresholding:
pred_original_sample = scheduler._threshold_sample(pred_original_sample)
elif scheduler.config.clip_sample:
pred_original_sample = pred_original_sample.clamp(
-scheduler.config.clip_sample_range, scheduler.config.clip_sample_range
)
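    # 4. Compute the previous sample mean. Substituting
    #    x_t = sqrt(alpha_prod_t) * x_0 + sqrt(beta_prod_t) * eps into the DDPM posterior
    #    mean collapses it to sqrt(alpha_prod_t_prev) * x_0 + coef_D * eps. No noise is
    #    added, so the step is deterministic. The closed form below treats `model_output`
    #    as the predicted noise (i.e. the "epsilon" prediction type).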
current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
    coef_D = current_sample_coeff * (beta_prod_t ** (0.5))  # = sqrt(alpha_t) * beta_prod_t_prev / sqrt(beta_prod_t)
pred_prev_sample = (alpha_prod_t_prev ** (0.5) * pred_original_sample) + (
coef_D * model_output
)
return pred_prev_sample, coef_D
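

if __name__ == "__main__":
    # Illustrative usage sketch (an assumption, not taken from the original repo):
    # run a few deterministic reverse steps with diffusers' DDPMScheduler, using a
    # zero tensor as a stand-in for a real UNet's predicted noise. Assumes the
    # `diffusers` package is installed.
    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler(num_train_timesteps=1000, prediction_type="epsilon")
    scheduler.set_timesteps(50)

    sample = torch.randn(1, 3, 64, 64)
    for t in scheduler.timesteps:
        model_output = torch.zeros_like(sample)  # replace with e.g. unet(sample, t).sample
        sample, coef_D = deterministic_ddpm_step(model_output, t, sample, scheduler)
    print(sample.shape, float(coef_D))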