Spaces:
Running
on
T4
Running
on
T4
from diffusionsfm.inference.ddim import inference_ddim | |
from diffusionsfm.utils.rays import ( | |
Rays, | |
rays_to_cameras, | |
rays_to_cameras_homography, | |
) | |
def predict_cameras(
    model,
    images,
    device,
    crop_parameters=None,
    stop_iteration=None,
    num_patches_x=16,
    num_patches_y=16,
    additional_timesteps=(),
    calculate_intrinsics=False,
    max_num_images=8,
    mode=None,
    return_rays=False,
    use_homogeneous=False,
    seed=0,
):
    """
    Run DDIM inference on one set of images and decode the predicted rays into cameras.

    Args:
        model: Model passed to ``inference_ddim``. Its ``depth_resolution``
            attribute is forwarded to the ray-to-camera conversion.
        images (torch.Tensor): (N, C, H, W). A leading batch dimension of
            size 1 is added before calling ``inference_ddim``.
        device: Forwarded unchanged to ``inference_ddim``.
        crop_parameters (torch.Tensor): (N, 4) or None. When given, a batch
            dimension is added for ``inference_ddim``; the unbatched tensor
            (or None) is passed to the ray-to-camera conversion.
        stop_iteration: Forwarded unchanged to ``inference_ddim``.
        num_patches_x (int): Patch-grid width used by ``Rays.from_spatial``
            and the ray-to-camera conversion.
        num_patches_y (int): Patch-grid height, used the same way.
        additional_timesteps (iterable of int): Indices into the intermediate
            predictions returned by ``inference_ddim``. Each one is decoded
            into an extra camera prediction.
        calculate_intrinsics (bool): If True, convert rays with
            ``rays_to_cameras_homography``; otherwise with ``rays_to_cameras``.
        max_num_images (int): Forwarded unchanged to ``inference_ddim``.
        mode: Forwarded unchanged to ``Rays.from_spatial``.
        return_rays (bool): If True, pair every camera prediction with the
            decoded rays it came from.
        use_homogeneous (bool): Forwarded unchanged to ``Rays.from_spatial``.
        seed (int): Currently unused; kept for interface compatibility.

    Returns:
        If ``return_rays`` is True:
            ``((pred_cam, spatial_rays), additional_predictions)``, where each
            entry of ``additional_predictions`` is a ``(cam, rays)`` tuple.
        Otherwise:
            ``(pred_cam, additional_predictions, spatial_rays)``, where each
            entry of ``additional_predictions`` is a camera prediction.
        The two layouts differ and are kept as-is for existing callers.
    """
    ray_to_cam = (
        rays_to_cameras_homography if calculate_intrinsics else rays_to_cameras
    )

    # ``crop_parameters`` defaults to None, so only add the batch dimension
    # when it is actually provided instead of crashing on ``None.unsqueeze``.
    batched_crop_parameters = (
        crop_parameters.unsqueeze(0) if crop_parameters is not None else None
    )

    def _decode(raw_rays):
        """Turn one batch element of raw ray predictions into (cameras, rays)."""
        rays = Rays.from_spatial(
            raw_rays,
            mode=mode,
            num_patches_x=num_patches_x,
            num_patches_y=num_patches_y,
            use_homogeneous=use_homogeneous,
        )
        # Final and intermediate predictions share this single conversion so
        # they are decoded with identical settings (including depth_resolution).
        cameras = ray_to_cam(
            rays,
            crop_parameters,
            num_patches_x=num_patches_x,
            num_patches_y=num_patches_y,
            depth_resolution=model.depth_resolution,
            average_centers=True,
            directions_from_averaged_center=True,
        )
        return cameras, rays

    # The second and fourth return values of inference_ddim are not used here.
    rays_final, _, pred_intermediate, _ = inference_ddim(
        model,
        images.unsqueeze(0),  # add a batch dimension of size 1
        device,
        crop_parameters=batched_crop_parameters,
        pbar=False,
        stop_iteration=stop_iteration,
        eta=[1, 0],
        num_inference_steps=100,
        num_patches_x=num_patches_x,
        num_patches_y=num_patches_y,
        # NOTE(review): presumably required so pred_intermediate is populated
        # for additional_timesteps — TODO confirm against inference_ddim.
        visualize=True,
        max_num_images=max_num_images,
    )

    # Index 0 selects the single batch element added above.
    pred_cam, spatial_rays = _decode(rays_final[0])

    additional_predictions = []
    for t in additional_timesteps:
        cam, rays = _decode(pred_intermediate[t][0])
        additional_predictions.append((cam, rays) if return_rays else cam)

    if return_rays:
        return (pred_cam, spatial_rays), additional_predictions
    return pred_cam, additional_predictions, spatial_rays