qitaoz committed · verified
Commit 5e82535 · Parent(s): 5b836a4

Upload 2 files

diffusionsfm/inference/ddim.py ADDED
@@ -0,0 +1,181 @@
+ import torch
+ import random
+ import numpy as np
+ from tqdm.auto import tqdm
+
+ from diffusionsfm.utils.rays import compute_ndc_coordinates
+
+
+ def inference_ddim(
+     model,
+     images,
+     device,
+     crop_parameters=None,
+     eta=0,
+     num_inference_steps=100,
+     pbar=True,
+     stop_iteration=None,
+     num_patches_x=16,
+     num_patches_y=16,
+     visualize=False,
+     max_num_images=8,
+     seed=0,
+ ):
+     """
+     Implements DDIM-style inference.
+
+     To get multiple samples, batch the images multiple times.
+
+     Args:
+         model: Ray diffusion model.
+         images (torch.Tensor): (B, N, C, H, W).
+         crop_parameters (torch.Tensor, optional): Crop parameters used to
+             compute NDC coordinates when model.append_ndc is set.
+         eta (float or [float, float], optional): Stochasticity coefficient.
+             0 is completely deterministic, 1 is equivalent to DDPM. A pair
+             [eta_0, eta_1] scales the direction and noise terms separately.
+             (Default: 0)
+         num_inference_steps (int, optional): Number of inference steps.
+             (Default: 100)
+         pbar (bool, optional): Whether to show a progress bar. (Default: True)
+     """
+     timesteps = model.noise_scheduler.compute_inference_timesteps(num_inference_steps)
+     batch_size = images.shape[0]
+     num_images = images.shape[1]
+
+     if isinstance(eta, list):
+         eta_0, eta_1 = float(eta[0]), float(eta[1])
+     else:
+         # Scalar eta applies to both terms (the original code silently dropped it)
+         eta_0 = eta_1 = float(eta)
+
+     # Fix all random seeds for reproducibility
+     if seed is not None:
+         torch.manual_seed(seed)
+         random.seed(seed)
+         np.random.seed(seed)
+
+     with torch.no_grad():
+         x_tau = torch.randn(
+             batch_size,
+             num_images,
+             model.ray_out if hasattr(model, "ray_out") else model.ray_dim,
+             num_patches_x,
+             num_patches_y,
+             device=device,
+         )
+
+         if visualize:
+             x_taus = [x_tau]
+             all_pred = []
+             noise_samples = []
+
+         image_features = model.feature_extractor(images, autoresize=True)
+
+         if model.append_ndc:
+             ndc_coordinates = compute_ndc_coordinates(
+                 crop_parameters=crop_parameters,
+                 no_crop_param_device="cpu",
+                 num_patches_x=model.width,
+                 num_patches_y=model.width,
+                 distortion_coeffs=None,
+             )[..., :2].to(device)
+             ndc_coordinates = ndc_coordinates.permute(0, 1, 4, 2, 3)
+         else:
+             ndc_coordinates = None
+
+         if stop_iteration is None:
+             loop = range(len(timesteps))
+         else:
+             loop = range(len(timesteps) - stop_iteration + 1)
+         loop = tqdm(loop) if pbar else loop
+
+         for t in loop:
+             tau = timesteps[t]
+
+             if tau > 0 and eta_1 > 0:
+                 z = torch.randn(
+                     batch_size,
+                     num_images,
+                     model.ray_out if hasattr(model, "ray_out") else model.ray_dim,
+                     num_patches_x,
+                     num_patches_y,
+                     device=device,
+                 )
+             else:
+                 z = 0
+
+             alpha = model.noise_scheduler.alphas_cumprod[tau]
+             if tau > 0:
+                 tau_prev = timesteps[t + 1]
+                 alpha_prev = model.noise_scheduler.alphas_cumprod[tau_prev]
+             else:
+                 alpha_prev = torch.tensor(1.0, device=device).float()
+
+             sigma_t = torch.sqrt((1 - alpha_prev) / (1 - alpha)) * torch.sqrt(
+                 1 - alpha / alpha_prev
+             )
+
+             if num_images > max_num_images:
+                 # Too many images for one forward pass: run the model on
+                 # random splits that all share the reference image (index 0)
+                 eps_pred = torch.zeros_like(x_tau)
+                 noise_sample = torch.zeros_like(x_tau)
+
+                 # Randomly split image indices (excluding index 0), then prepend 0 to each split
+                 indices_split = torch.split(
+                     torch.randperm(num_images - 1) + 1, max_num_images - 1
+                 )
+
+                 for indices in indices_split:
+                     # Ensure index 0 is always included
+                     indices = torch.cat((torch.tensor([0]), indices))
+
+                     eps_pred_ind, noise_sample_ind = model(
+                         features=image_features[:, indices],
+                         rays_noisy=x_tau[:, indices],
+                         t=int(tau),
+                         ndc_coordinates=(
+                             ndc_coordinates[:, indices]
+                             if ndc_coordinates is not None
+                             else None
+                         ),
+                         indices=indices,
+                     )
+
+                     eps_pred[:, indices] += eps_pred_ind
+
+                     if noise_sample_ind is not None:
+                         noise_sample[:, indices] += noise_sample_ind
+
+                 # Average over splits for the shared reference index (0)
+                 eps_pred[:, 0] /= len(indices_split)
+                 noise_sample[:, 0] /= len(indices_split)
+             else:
+                 eps_pred, noise_sample = model(
+                     features=image_features,
+                     rays_noisy=x_tau,
+                     t=int(tau),
+                     ndc_coordinates=ndc_coordinates,
+                 )
+
+             if model.use_homogeneous:
+                 # Normalize the two homogeneous components to unit norm
+                 p1 = eps_pred[:, :, :4]
+                 p2 = eps_pred[:, :, 4:]
+
+                 c1 = torch.linalg.norm(p1, dim=2, keepdim=True)
+                 c2 = torch.linalg.norm(p2, dim=2, keepdim=True)
+                 eps_pred[:, :, :4] = p1 / c1
+                 eps_pred[:, :, 4:] = p2 / c2
+
+             if visualize:
+                 all_pred.append(eps_pred.clone())
+                 noise_samples.append(noise_sample)
+
+             # The network predicts x_0 directly; recover the noise estimate
+             # from x_tau and the x_0 prediction
+             x0_pred = eps_pred.clone()
+             eps_pred = (x_tau - torch.sqrt(alpha) * eps_pred) / torch.sqrt(1 - alpha)
+
+             # DDIM update: deterministic direction term plus optional noise
+             dir_x_tau = torch.sqrt(1 - alpha_prev - eta_0 * sigma_t**2) * eps_pred
+             noise = eta_1 * sigma_t * z
+             x_tau = torch.sqrt(alpha_prev) * x0_pred + dir_x_tau + noise
+
+             if visualize:
+                 x_taus.append(x_tau.detach().clone())
+
+     if visualize:
+         return x_tau, x_taus, all_pred, noise_samples
+     return x_tau
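For reference, the update at the end of the loop is the standard DDIM step (Song et al., 2021) applied to a network that predicts x_0 directly; the only twist in this code is splitting the usual single η into η₀ (direction term) and η₁ (noise term):

```latex
\hat{\epsilon} = \frac{x_\tau - \sqrt{\bar\alpha_\tau}\,\hat{x}_0}{\sqrt{1 - \bar\alpha_\tau}},
\qquad
\sigma_\tau = \sqrt{\frac{1 - \bar\alpha_{\tau'}}{1 - \bar\alpha_\tau}}
              \sqrt{1 - \frac{\bar\alpha_\tau}{\bar\alpha_{\tau'}}},
\qquad
x_{\tau'} = \sqrt{\bar\alpha_{\tau'}}\,\hat{x}_0
          + \sqrt{1 - \bar\alpha_{\tau'} - \eta_0 \sigma_\tau^2}\,\hat{\epsilon}
          + \eta_1 \sigma_\tau z,
```

where τ' is the next (less noisy) timestep and z ∼ N(0, I). Setting η₀ = η₁ = 0 gives deterministic DDIM, while η₀ = η₁ = 1 recovers DDPM-style sampling; the `eta=[1, 0]` used by `predict_cameras` below keeps the DDPM-scaled direction term but adds no noise.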
diffusionsfm/inference/predict.py ADDED
@@ -0,0 +1,96 @@
+ from diffusionsfm.inference.ddim import inference_ddim
+ from diffusionsfm.utils.rays import (
+     Rays,
+     rays_to_cameras,
+     rays_to_cameras_homography,
+ )
+
+
+ def predict_cameras(
+     model,
+     images,
+     device,
+     crop_parameters=None,
+     stop_iteration=None,
+     num_patches_x=16,
+     num_patches_y=16,
+     additional_timesteps=(),
+     calculate_intrinsics=False,
+     max_num_images=8,
+     mode=None,
+     return_rays=False,
+     use_homogeneous=False,
+     seed=0,
+ ):
+     """
+     Args:
+         images (torch.Tensor): (N, C, H, W)
+         crop_parameters (torch.Tensor): (N, 4) or None
+     """
+     # Recover intrinsics via a homography fit when requested
+     if calculate_intrinsics:
+         ray_to_cam = rays_to_cameras_homography
+     else:
+         ray_to_cam = rays_to_cameras
+
+     get_spatial_rays = Rays.from_spatial
+
+     # Batch dimension of 1: a single scene of N images
+     rays_final, rays_intermediate, pred_intermediate, _ = inference_ddim(
+         model,
+         images.unsqueeze(0),
+         device,
+         crop_parameters=(
+             crop_parameters.unsqueeze(0) if crop_parameters is not None else None
+         ),
+         pbar=False,
+         stop_iteration=stop_iteration,
+         eta=[1, 0],
+         num_inference_steps=100,
+         num_patches_x=num_patches_x,
+         num_patches_y=num_patches_y,
+         visualize=True,
+         max_num_images=max_num_images,
+         seed=seed,
+     )
+
+     spatial_rays = get_spatial_rays(
+         rays_final[0],
+         mode=mode,
+         num_patches_x=num_patches_x,
+         num_patches_y=num_patches_y,
+         use_homogeneous=use_homogeneous,
+     )
+
+     pred_cam = ray_to_cam(
+         spatial_rays,
+         crop_parameters,
+         num_patches_x=num_patches_x,
+         num_patches_y=num_patches_y,
+         depth_resolution=model.depth_resolution,
+         average_centers=True,
+         directions_from_averaged_center=True,
+     )
+
+     # Also recover cameras at the requested intermediate denoising steps
+     additional_predictions = []
+     for t in additional_timesteps:
+         ray = pred_intermediate[t]
+
+         ray = get_spatial_rays(
+             ray[0],
+             mode=mode,
+             num_patches_x=num_patches_x,
+             num_patches_y=num_patches_y,
+             use_homogeneous=use_homogeneous,
+         )
+
+         cam = ray_to_cam(
+             ray,
+             crop_parameters,
+             num_patches_x=num_patches_x,
+             num_patches_y=num_patches_y,
+             average_centers=True,
+             directions_from_averaged_center=True,
+         )
+         if return_rays:
+             cam = (cam, ray)
+         additional_predictions.append(cam)
+
+     if return_rays:
+         return (pred_cam, spatial_rays), additional_predictions
+     return pred_cam, additional_predictions, spatial_rays
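A minimal usage sketch tying the two files together (not part of the commit: `load_model` is a hypothetical stand-in for however the repo builds and restores a trained model, and the image resolution and crop-parameter values are placeholders; only `predict_cameras` and its arguments come from the code above):

```python
import torch
from diffusionsfm.inference.predict import predict_cameras

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model().to(device).eval()  # hypothetical checkpoint-loading helper
images = torch.rand(4, 3, 224, 224, device=device)  # (N, C, H, W): 4 input views
crop_parameters = torch.zeros(4, 4, device=device)  # (N, 4); placeholder values

# Runs 100 DDIM steps with eta=[1, 0] and converts the final rays to cameras
pred_cam, additional_preds, spatial_rays = predict_cameras(
    model,
    images,
    device,
    crop_parameters=crop_parameters,
    calculate_intrinsics=True,  # fit intrinsics via rays_to_cameras_homography
)
```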