|
import torch |
|
from diffusers.utils.torch_utils import randn_tensor |
|
|
|
'''
Customized DDPM step function.

Instead of stepping each view's latent independently, the denoising step is
performed on a single shared texture map baked from all views, keeping the
views mutually consistent.
'''
|
@torch.no_grad()
def step_tex_sync(
    scheduler,
    uvp,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    texture=None,
    generator=None,
    return_dict: bool = True,
    guidance_scale=1,
    main_views=None,
    hires_original_views=True,
    exp=None,
    cos_weighted=True,
):
    """Perform one DDPM denoising step on a shared texture map.

    Per-view model outputs are baked into a single UV texture, the DDPM
    posterior update is computed in texture space, and the updated texture is
    rendered back into per-view latents so all views stay consistent.

    Args:
        scheduler: A diffusers DDPM-style scheduler; this function reads its
            ``alphas_cumprod``, ``variance_type``, ``config`` and calls
            ``previous_timestep`` / ``_get_variance``.
        uvp: UV-projection helper providing ``bake_texture``,
            ``set_texture_map`` and ``render_textured_views``.
        model_output: Model prediction for all views, stacked on dim 0. May
            contain learned-variance channels (channel count doubled).
        timestep: Current discrete timestep ``t``.
        sample: Current noisy latents for all views, stacked on dim 0.
        texture: Optional pre-baked texture of ``sample``; baked here if None.
        generator: Optional ``torch.Generator`` for the variance noise.
        return_dict: Unused; kept for signature compatibility.
        guidance_scale: Unused; kept for signature compatibility.
        main_views: View indices forwarded to ``uvp.bake_texture``.
        hires_original_views: Unused; kept for signature compatibility.
        exp: Exponent forwarded to ``uvp.bake_texture``.
        cos_weighted: Forwarded to ``uvp.bake_texture`` for the variance bake.

    Returns:
        dict with ``prev_sample`` (per-view latents at ``t-1``),
        ``pred_original_sample`` (per-view x_0 prediction) and ``prev_tex``
        (the texture at ``t-1``).

    Raises:
        ValueError: If ``scheduler.config.prediction_type`` is not one of
            ``epsilon``, ``sample`` or ``v_prediction``.
    """
    main_views = [] if main_views is None else main_views

    t = timestep
    prev_t = scheduler.previous_timestep(t)

    # Split off the learned variance channels when the model predicts them.
    if model_output.shape[1] == sample.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
        model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
    else:
        predicted_variance = None

    # 1. Compute alphas/betas for the current transition t -> t-1.
    alpha_prod_t = scheduler.alphas_cumprod[t]
    alpha_prod_t_prev = scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev
    current_beta_t = 1 - current_alpha_t

    # 2. Predict x_0 from the model output (per prediction parameterization).
    if scheduler.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    elif scheduler.config.prediction_type == "sample":
        pred_original_sample = model_output
    elif scheduler.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for the DDPMScheduler."
        )

    # 3. Clip/threshold the predicted x_0 as configured on the scheduler.
    if scheduler.config.thresholding:
        pred_original_sample = scheduler._threshold_sample(pred_original_sample)
    elif scheduler.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -scheduler.config.clip_sample_range, scheduler.config.clip_sample_range
        )

    # 4. DDPM posterior mean coefficients (Ho et al., eq. 7).
    pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
    current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

    # 5. Multidiffusion-style synchronization: move to texture space.
    if texture is None:
        # Bake the current noisy views into one shared texture.
        sample_views = [view for view in sample]
        _, texture, _ = uvp.bake_texture(views=sample_views, main_views=main_views, exp=exp)

    # Bake the per-view x_0 predictions into a shared texture.
    original_views = [view for view in pred_original_sample]
    _, original_tex, _ = uvp.bake_texture(views=original_views, main_views=main_views, exp=exp)
    # NOTE(review): the rendered views are never read; the set/render pair is
    # kept in case uvp caches state as a side effect — confirm and drop if not.
    uvp.set_texture_map(original_tex)
    uvp.render_textured_views()

    # 6. Posterior mean, computed entirely in texture space.
    prev_tex = pred_original_sample_coeff * original_tex + current_sample_coeff * texture

    # 7. Bake the learned variance (if any) into texture space as well.
    if predicted_variance is not None:
        variance_views = [view for view in predicted_variance]
        _, variance_tex, _ = uvp.bake_texture(
            views=variance_views, main_views=main_views, cos_weighted=cos_weighted, exp=exp
        )
    else:
        variance_tex = None

    # 8. Add the step's noise (skipped at the final step t == 0).
    if t > 0:
        device = texture.device
        variance_noise = randn_tensor(
            texture.shape, generator=generator, device=device, dtype=texture.dtype
        )
        if scheduler.variance_type == "fixed_small_log":
            variance = scheduler._get_variance(t, predicted_variance=variance_tex) * variance_noise
        elif scheduler.variance_type == "learned_range":
            # _get_variance returns log-variance for "learned_range".
            variance = scheduler._get_variance(t, predicted_variance=variance_tex)
            variance = torch.exp(0.5 * variance) * variance_noise
        else:
            variance = (scheduler._get_variance(t, predicted_variance=variance_tex) ** 0.5) * variance_noise
        prev_tex = prev_tex + variance

    # 9. Render the updated texture back into per-view latents.
    uvp.set_texture_map(prev_tex)
    prev_views = uvp.render_textured_views()
    pred_prev_sample = torch.clone(sample)
    for i, view in enumerate(prev_views):
        pred_prev_sample[i] = view[:-1]  # drop the trailing mask/alpha channel

    return {"prev_sample": pred_prev_sample, "pred_original_sample": pred_original_sample, "prev_tex": prev_tex}
|
|
|
|