from typing import Optional

import torch
from diffusers.utils.torch_utils import randn_tensor

# Customized scheduler step function: performs the DDPM update on a shared texture
# instead of independently on each view.


def _predict_original_sample(scheduler, sample, model_output, alpha_prod_t, beta_prod_t):
    """Return the predicted clean sample x_0, clipped or thresholded per ``scheduler.config``.

    Implements formula (15) of https://arxiv.org/pdf/2006.11239.pdf for each supported
    ``scheduler.config.prediction_type``.

    Raises:
        ValueError: if ``scheduler.config.prediction_type`` is not one of
            ``epsilon``, ``sample`` or ``v_prediction``.
    """
    prediction_type = scheduler.config.prediction_type
    if prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    elif prediction_type == "sample":
        pred_original_sample = model_output
    elif prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for the DDPMScheduler."
        )

    # 3. Clip or threshold "predicted x_0"
    if scheduler.config.thresholding:
        pred_original_sample = scheduler._threshold_sample(pred_original_sample)
    elif scheduler.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -scheduler.config.clip_sample_range, scheduler.config.clip_sample_range
        )
    return pred_original_sample


def _texture_noise(scheduler, t, texture, variance_tex, generator):
    """Sample the stochastic DDPM term for a texture, scaled according to ``scheduler.variance_type``.

    ``variance_tex`` is the baked predicted variance (or None) and is forwarded to
    ``scheduler._get_variance``.
    """
    variance_noise = randn_tensor(
        texture.shape, generator=generator, device=texture.device, dtype=texture.dtype
    )
    if scheduler.variance_type == "fixed_small_log":
        return scheduler._get_variance(t, predicted_variance=variance_tex) * variance_noise
    if scheduler.variance_type == "learned_range":
        # _get_variance yields a log-variance here, hence exp(0.5 * .) -> standard deviation.
        log_variance = scheduler._get_variance(t, predicted_variance=variance_tex)
        return torch.exp(0.5 * log_variance) * variance_noise
    return (scheduler._get_variance(t, predicted_variance=variance_tex) ** 0.5) * variance_noise


@torch.no_grad()
def step_tex_sync(
    scheduler,
    uvp,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    texture: Optional[torch.FloatTensor],
    generator=None,
    return_dict: bool = True,
    guidance_scale=1,
    main_views: Optional[list] = None,
    hires_original_views=True,
    exp=None,
    cos_weighted=True,
):
    """Run one DDPM denoising step on a shared texture and re-render it into every view.

    The per-view prediction is turned into a predicted clean sample x_0. That sample is
    baked into a texture with ``uvp.bake_texture``. The DDPM posterior-mean update
    (formula (7) of https://arxiv.org/pdf/2006.11239.pdf) is applied in texture space,
    and noise is added when ``t > 0``. The resulting texture is rendered back into the
    views to form the previous sample.

    Args:
        scheduler: DDPM scheduler (presumably a diffusers ``DDPMScheduler``). This reads
            ``previous_timestep``, ``alphas_cumprod``, ``one``, ``variance_type`` and
            ``config``, and the private ``_threshold_sample`` / ``_get_variance`` helpers.
        uvp: Projection helper exposing ``bake_texture``, ``set_texture_map`` and
            ``render_textured_views``.
        model_output: Per-view model prediction. It is split into
            (prediction, predicted variance) along dim 1 when that dim is twice
            ``sample``'s and ``scheduler.variance_type`` is learned.
        timestep: Current timestep t.
        sample: Current noisy views x_t, one view per entry along dim 0.
        texture: Current noisy texture. If None, it is baked from ``sample``.
        generator: Random generator passed to ``randn_tensor``.
        return_dict: If True, return a dict; otherwise return a tuple.
        guidance_scale: Unused in this function.
        main_views: Forwarded to ``uvp.bake_texture``. None means an empty list.
        hires_original_views: Unused in this function.
        exp: Forwarded to ``uvp.bake_texture``.
        cos_weighted: Forwarded to ``uvp.bake_texture``, but only when baking the
            predicted variance.

    Returns:
        If ``return_dict`` is True, a dict with:
            - ``"prev_sample"``: views for step t-1;
            - ``"pred_original_sample"``: per-view x_0;
            - ``"prev_tex"``: the texture for step t-1.
        Otherwise the tuple ``(prev_sample, pred_original_sample)``.

    Raises:
        ValueError: if ``scheduler.config.prediction_type`` is unsupported.
    """
    # None instead of a shared mutable [] default; callers still see an empty list.
    if main_views is None:
        main_views = []

    t = timestep
    prev_t = scheduler.previous_timestep(t)

    # Learned-variance models stack [prediction, variance] along dim 1.
    if model_output.shape[1] == sample.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
        model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
    else:
        predicted_variance = None

    # 1. compute alphas, betas
    alpha_prod_t = scheduler.alphas_cumprod[t]
    alpha_prod_t_prev = scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev
    current_beta_t = 1 - current_alpha_t

    # 2-3. predicted x_0 (formula (15)), clipped/thresholded per scheduler config
    pred_original_sample = _predict_original_sample(
        scheduler, sample, model_output, alpha_prod_t, beta_prod_t
    )

    # 4. Coefficients for x_0 and x_t in formula (7)
    pred_original_sample_coeff = (alpha_prod_t_prev ** 0.5 * current_beta_t) / beta_prod_t
    current_sample_coeff = current_alpha_t ** 0.5 * beta_prod_t_prev / beta_prod_t

    # Multi-view sync: bake per-view tensors into a shared texture and update that instead.
    if texture is None:
        # No running texture yet: bake the current noisy views x_t into one.
        _, texture, _ = uvp.bake_texture(views=list(sample), main_views=main_views, exp=exp)

    _, original_tex, _ = uvp.bake_texture(views=list(pred_original_sample), main_views=main_views, exp=exp)
    uvp.set_texture_map(original_tex)
    # NOTE(review): the rendered x_0 views are not used below. The call is kept in case
    # render_textured_views has side effects; remove it if it is confirmed pure.
    uvp.render_textured_views()

    # 5. Predicted previous-sample mean mu_t, computed in texture space (formula (7))
    prev_tex = pred_original_sample_coeff * original_tex + current_sample_coeff * texture

    # 6. Add noise
    if predicted_variance is not None:
        _, variance_tex, _ = uvp.bake_texture(
            views=list(predicted_variance), main_views=main_views, cos_weighted=cos_weighted, exp=exp
        )
    else:
        variance_tex = None

    if t > 0:
        prev_tex = prev_tex + _texture_noise(scheduler, t, texture, variance_tex, generator)

    # Render the updated texture back into every view.
    uvp.set_texture_map(prev_tex)
    prev_views = uvp.render_textured_views()
    pred_prev_sample = torch.clone(sample)
    for i, view in enumerate(prev_views):
        # Drop the last rendered channel (presumably a coverage/alpha mask — confirm).
        pred_prev_sample[i] = view[:-1]

    if not return_dict:
        return pred_prev_sample, pred_original_sample
    return {"prev_sample": pred_prev_sample, "pred_original_sample": pred_original_sample, "prev_tex": prev_tex}