import random

import numpy as np
import torch
from tqdm.auto import tqdm

from diffusionsfm.utils.rays import compute_ndc_coordinates


def inference_ddim(
    model,
    images,
    device,
    crop_parameters=None,
    eta=0,
    num_inference_steps=100,
    pbar=True,
    stop_iteration=None,
    num_patches_x=16,
    num_patches_y=16,
    visualize=False,
    max_num_images=8,
    seed=0,
):
    """
    Implements DDIM-style inference.

    To get multiple samples, batch the images multiple times.

    The model's first output is used as the estimate of the clean sample (x0);
    the noise estimate needed by the DDIM update is derived from it.

    Args:
        model: Ray Diffuser. Must expose ``noise_scheduler`` (with
            ``compute_inference_timesteps`` and ``alphas_cumprod``),
            ``feature_extractor``, ``append_ndc``, ``use_homogeneous``,
            ``width`` and either ``ray_out`` or ``ray_dim``. It is called with
            the keywords ``features``, ``rays_noisy``, ``t``,
            ``ndc_coordinates`` (plus ``indices`` in chunked mode) and must
            return a ``(prediction, noise_sample)`` pair, where
            ``noise_sample`` may be None.
        images (torch.Tensor): (B, N, C, H, W).
        device: Device on which the sampled rays are created.
        crop_parameters: Forwarded to ``compute_ndc_coordinates``. Only used
            when ``model.append_ndc`` is truthy. (Default: None)
        eta (list, optional): Two-element list (or tuple) ``[eta_0, eta_1]``.
            ``eta_0`` scales ``sigma_t**2`` inside the direction term and
            ``eta_1`` scales the injected noise. Any other value, including a
            plain number, leaves both coefficients at 0, i.e. fully
            deterministic sampling. (Default: 0)
        num_inference_steps (int, optional): Number of inference steps.
            (Default: 100)
        pbar (bool, optional): Whether to show progress bar. (Default: True)
        stop_iteration (int, optional): If given, only the first
            ``len(timesteps) - stop_iteration + 1`` timesteps are processed.
            (Default: None)
        num_patches_x (int, optional): Size of the second-to-last axis of the
            sampled ray tensor. (Default: 16)
        num_patches_y (int, optional): Size of the last axis of the sampled
            ray tensor. (Default: 16)
        visualize (bool, optional): Also return the intermediate states,
            predictions and noise samples. (Default: False)
        max_num_images (int, optional): If N exceeds this value, every step
            runs the model on random groups of at most ``max_num_images``
            images. Image 0 is part of every group and its prediction is
            averaged over the groups. (Default: 8)
        seed (int, optional): If not None, seeds the global ``torch``,
            ``random`` and ``numpy`` generators. (Default: 0)

    Returns:
        The final sample ``x_tau`` of shape
        (B, N, ray channels, num_patches_x, num_patches_y). If ``visualize``
        is set, the tuple ``(x_tau, x_taus, all_pred, noise_samples)`` where
        the last three are per-step lists.

    Raises:
        ValueError: If chunked inference is required (N > max_num_images) but
            ``max_num_images`` is smaller than 2.
    """
    scheduler = model.noise_scheduler
    timesteps = scheduler.compute_inference_timesteps(num_inference_steps)
    batch_size, num_images = images.shape[0], images.shape[1]

    chunked = num_images > max_num_images
    if chunked and max_num_images < 2:
        # Each group holds image 0 plus at least one other image; a smaller
        # budget would hand torch.split a non-positive split size.
        raise ValueError(
            "max_num_images must be at least 2 when there are more images "
            f"than max_num_images (got max_num_images={max_num_images}, "
            f"num_images={num_images})."
        )

    if isinstance(eta, (list, tuple)):
        eta_0, eta_1 = float(eta[0]), float(eta[1])
    else:
        # NOTE(review): a scalar eta is ignored here (both coefficients stay
        # 0). Kept as-is so existing callers get unchanged results.
        eta_0, eta_1 = 0, 0

    # Fixing seed (this reseeds the process-wide generators).
    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

    ray_channels = model.ray_out if hasattr(model, "ray_out") else model.ray_dim
    sample_shape = (
        batch_size,
        num_images,
        ray_channels,
        num_patches_x,
        num_patches_y,
    )

    with torch.no_grad():
        x_tau = torch.randn(*sample_shape, device=device)

        if visualize:
            x_taus = [x_tau]
            all_pred = []
            noise_samples = []

        image_features = model.feature_extractor(images, autoresize=True)

        if model.append_ndc:
            ndc_coordinates = compute_ndc_coordinates(
                crop_parameters=crop_parameters,
                no_crop_param_device="cpu",
                num_patches_x=model.width,
                num_patches_y=model.width,
                distortion_coeffs=None,
            )[..., :2].to(device)
            # Keep the first two components of the last axis and move that
            # axis to position 2.
            ndc_coordinates = ndc_coordinates.permute(0, 1, 4, 2, 3)
        else:
            ndc_coordinates = None

        if stop_iteration is None:
            num_steps = len(timesteps)
        else:
            num_steps = len(timesteps) - stop_iteration + 1
        loop = range(num_steps)
        if pbar:
            loop = tqdm(loop)

        for t in loop:
            tau = timesteps[t]

            # Fresh noise is only needed when it is actually injected.
            if tau > 0 and eta_1 > 0:
                z = torch.randn(*sample_shape, device=device)
            else:
                z = 0

            alpha = scheduler.alphas_cumprod[tau]
            if tau > 0:
                # Assumes the schedule ends at timestep 0, so t + 1 is always
                # a valid index here -- TODO confirm against the scheduler.
                tau_prev = timesteps[t + 1]
                alpha_prev = scheduler.alphas_cumprod[tau_prev]
            else:
                alpha_prev = torch.tensor(1.0, device=device).float()

            sigma_t = torch.sqrt((1 - alpha_prev) / (1 - alpha)) * torch.sqrt(
                1 - alpha / alpha_prev
            )

            if chunked:
                x0_pred = torch.zeros_like(x_tau)
                noise_sample = torch.zeros_like(x_tau)

                # Randomly split image indices (excluding index 0), then
                # prepend 0 to each split.
                indices_split = torch.split(
                    torch.randperm(num_images - 1) + 1, max_num_images - 1
                )

                for split in indices_split:
                    # Ensure index 0 is always included.
                    indices = torch.cat((torch.tensor([0]), split))

                    # Without NDC conditioning there is no grid to slice;
                    # pass None through exactly like the unchunked call.
                    if ndc_coordinates is None:
                        ndc_subset = None
                    else:
                        ndc_subset = ndc_coordinates[:, indices]

                    x0_pred_ind, noise_sample_ind = model(
                        features=image_features[:, indices],
                        rays_noisy=x_tau[:, indices],
                        t=int(tau),
                        ndc_coordinates=ndc_subset,
                        indices=indices,
                    )

                    x0_pred[:, indices] += x0_pred_ind
                    if noise_sample_ind is not None:
                        noise_sample[:, indices] += noise_sample_ind

                # Average over splits for the shared reference index (0);
                # every other index belongs to exactly one split.
                x0_pred[:, 0] /= len(indices_split)
                noise_sample[:, 0] /= len(indices_split)
            else:
                x0_pred, noise_sample = model(
                    features=image_features,
                    rays_noisy=x_tau,
                    t=int(tau),
                    ndc_coordinates=ndc_coordinates,
                )

            if model.use_homogeneous:
                # Normalise channels [0, 4) and [4, end) to unit L2 norm
                # along the channel axis, independently of each other.
                p1 = x0_pred[:, :, :4]
                p2 = x0_pred[:, :, 4:]
                c1 = torch.linalg.norm(p1, dim=2, keepdim=True)
                c2 = torch.linalg.norm(p2, dim=2, keepdim=True)
                x0_pred[:, :, :4] = p1 / c1
                x0_pred[:, :, 4:] = p2 / c2

            if visualize:
                all_pred.append(x0_pred.clone())
                noise_samples.append(noise_sample)

            # Noise estimate implied by the current state and the x0 estimate.
            eps_pred = (x_tau - torch.sqrt(alpha) * x0_pred) / torch.sqrt(
                1 - alpha
            )

            dir_x_tau = (
                torch.sqrt(1 - alpha_prev - eta_0 * sigma_t**2) * eps_pred
            )
            noise = eta_1 * sigma_t * z

            x_tau = torch.sqrt(alpha_prev) * x0_pred + dir_x_tau + noise

            if visualize:
                x_taus.append(x_tau.detach().clone())

    if visualize:
        return x_tau, x_taus, all_pred, noise_samples
    return x_tau