qitaoz committed on
Commit
5b836a4
·
verified ·
1 Parent(s): e6521d1

Delete diffusionsfm/inference/ddim.py

Browse files
Files changed (1) hide show
  1. diffusionsfm/inference/ddim.py +0 -145
diffusionsfm/inference/ddim.py DELETED
@@ -1,145 +0,0 @@
1
- import torch
2
- import random
3
- import numpy as np
4
- from tqdm.auto import tqdm
5
-
6
- from diffusionsfm.utils.rays import compute_ndc_coordinates
7
-
8
-
9
def inference_ddim(
    model,
    images,
    device,
    crop_parameters=None,
    eta=0,
    num_inference_steps=100,
    pbar=True,
    num_patches_x=16,
    num_patches_y=16,
    visualize=False,
    seed=0,
):
    """
    Implements DDIM-style inference.

    To get multiple samples, batch the images multiple times.

    Args:
        model: Ray Diffuser. This function reads ``noise_scheduler`` (its
            ``compute_inference_timesteps`` method and ``alphas_cumprod`` table),
            ``feature_extractor``, ``append_ndc``, ``use_homogeneous``, ``ray_out``
            (falling back to ``ray_dim``) and, only when ``append_ndc`` is set,
            ``width``. Its output is treated as a prediction of the clean sample.
        images (torch.Tensor): (B, N, C, H, W).
        device: Device on which the noise tensors (and NDC grid) are created.
        crop_parameters (optional): Forwarded unchanged to
            ``compute_ndc_coordinates`` when ``model.append_ndc`` is set.
        eta (float | list | tuple, optional): Stochasticity coefficient. 0 is
            completely deterministic, 1 is equivalent to DDPM. A scalar follows the
            usual DDIM convention (sigma = eta * sigma_t). A two-element sequence
            ``(eta_0, eta_1)`` sets the two factors independently: ``eta_0`` scales
            ``sigma_t**2`` inside the direction term and ``eta_1`` scales the
            injected noise. (Default: 0)
        num_inference_steps (int, optional): Number of inference steps. (Default: 100)
        pbar (bool, optional): Whether to show progress bar. (Default: True)
        num_patches_x (int, optional): Size of the second-to-last axis of the
            sampled noise. (Default: 16)
        num_patches_y (int, optional): Size of the last axis of the sampled noise.
            (Default: 16)
        visualize (bool, optional): If True, also return the per-step intermediates.
            (Default: False)
        seed (int, optional): Seeds the global torch / random / numpy generators
            before sampling. Pass None to leave them untouched. (Default: 0)

    Returns:
        The final sample ``x_tau``. If ``visualize`` is True, the tuple
        ``(x_tau, x_taus, all_pred, noise_samples)`` where ``x_taus`` holds the
        initial noise followed by the sample after every step, ``all_pred`` the
        model's clean-sample prediction at every step and ``noise_samples`` the
        second output of the model at every step.
    """
    timesteps = model.noise_scheduler.compute_inference_timesteps(num_inference_steps)
    num_steps = len(timesteps)
    batch_size = images.shape[0]
    num_images = images.shape[1]

    if isinstance(eta, (list, tuple)):
        eta_0, eta_1 = float(eta[0]), float(eta[1])
    elif eta is None:
        eta_0, eta_1 = 0.0, 0.0
    else:
        # Scalar eta: sigma = eta * sigma_t, so the direction term sees
        # sigma**2 = eta**2 * sigma_t**2 while the noise term sees eta * sigma_t.
        eta_1 = float(eta)
        eta_0 = eta_1**2

    # Fixing seed
    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

    ray_channels = model.ray_out if hasattr(model, "ray_out") else model.ray_dim
    noise_shape = (batch_size, num_images, ray_channels, num_patches_x, num_patches_y)

    with torch.no_grad():
        x_tau = torch.randn(*noise_shape, device=device)

        if visualize:
            x_taus = [x_tau]
            all_pred = []
            noise_samples = []

        image_features = model.feature_extractor(images, autoresize=True)

        if model.append_ndc:
            ndc_coordinates = compute_ndc_coordinates(
                crop_parameters=crop_parameters,
                no_crop_param_device="cpu",
                num_patches_x=model.width,
                num_patches_y=model.width,
                distortion_coeffs=None,
            )[..., :2].to(device)
            ndc_coordinates = ndc_coordinates.permute(0, 1, 4, 2, 3)
        else:
            ndc_coordinates = None

        steps = range(num_steps)
        for step in tqdm(steps) if pbar else steps:
            tau = timesteps[step]

            if tau > 0 and eta_1 > 0:
                z = torch.randn(*noise_shape, device=device)
            else:
                z = 0

            alpha = model.noise_scheduler.alphas_cumprod[tau]
            if tau > 0 and step + 1 < num_steps:
                tau_prev = timesteps[step + 1]
                alpha_prev = model.noise_scheduler.alphas_cumprod[tau_prev]
            else:
                # Final step (tau == 0, or a schedule that stops above 0): jump
                # straight to the clean sample instead of indexing past the end.
                alpha_prev = torch.tensor(1.0, device=device).float()

            sigma_t = (
                torch.sqrt((1 - alpha_prev) / (1 - alpha))
                * torch.sqrt(1 - alpha / alpha_prev)
            )

            # The network output is used as the clean-sample (x0) prediction.
            x0_pred, noise_sample = model(
                features=image_features,
                rays_noisy=x_tau,
                t=int(tau),
                ndc_coordinates=ndc_coordinates,
            )

            if model.use_homogeneous:
                # Normalise the first four channels and the remaining channels
                # separately to unit length along the channel axis.
                p1 = x0_pred[:, :, :4]
                p2 = x0_pred[:, :, 4:]
                c1 = torch.linalg.norm(p1, dim=2, keepdim=True)
                c2 = torch.linalg.norm(p2, dim=2, keepdim=True)
                x0_pred = torch.cat([p1 / c1, p2 / c2], dim=2)

            if visualize:
                all_pred.append(x0_pred.clone())
                noise_samples.append(noise_sample)

            # Recover the implied noise from the x0 prediction.
            eps_pred = (x_tau - torch.sqrt(alpha) * x0_pred) / torch.sqrt(1 - alpha)

            # Clamp guards against a tiny negative argument caused by rounding,
            # which would otherwise turn the whole sample into NaN.
            dir_coeff = torch.sqrt(
                torch.clamp(1 - alpha_prev - eta_0 * sigma_t**2, min=0)
            )
            dir_x_tau = dir_coeff * eps_pred
            noise = eta_1 * sigma_t * z

            x_tau = torch.sqrt(alpha_prev) * x0_pred + dir_x_tau + noise

            if visualize:
                x_taus.append(x_tau.detach().clone())

    if visualize:
        return x_tau, x_taus, all_pred, noise_samples
    return x_tau