|
from pathlib import Path |
|
|
|
import numpy as np |
|
import torch |
|
from einops import einsum, rearrange |
|
from jaxtyping import Float |
|
from plyfile import PlyData, PlyElement |
|
from scipy.spatial.transform import Rotation as R |
|
from torch import Tensor |
|
|
|
|
|
def construct_list_of_attributes(num_rest: int) -> list[str]:
    """Return the ordered PLY vertex property names for a Gaussian file.

    The order is: position (x, y, z), normal (nx, ny, nz), three DC colour
    coefficients, ``num_rest`` additional ``f_rest_*`` coefficients,
    opacity, three scales and four rotation components.

    Args:
        num_rest: Number of ``f_rest_*`` properties to include (a
            non-positive value includes none).
    """
    position_and_normal = ["x", "y", "z", "nx", "ny", "nz"]
    dc_terms = [f"f_dc_{k}" for k in range(3)]
    rest_terms = [f"f_rest_{k}" for k in range(num_rest)]
    scale_terms = [f"scale_{k}" for k in range(3)]
    rotation_terms = [f"rot_{k}" for k in range(4)]
    return (
        position_and_normal
        + dc_terms
        + rest_terms
        + ["opacity"]
        + scale_terms
        + rotation_terms
    )
|
|
|
|
|
def _to_numpy(tensor: Tensor) -> np.ndarray:
    """Detach ``tensor``, move it to the CPU and return a contiguous numpy view."""
    return tensor.detach().cpu().contiguous().numpy()


def export_ply(
    means: Float[Tensor, "gaussian 3"],
    scales: Float[Tensor, "gaussian 3"],
    rotations: Float[Tensor, "gaussian 4"],
    harmonics: Float[Tensor, "gaussian 3 d_sh"],
    opacities: Float[Tensor, " gaussian"],
    path: Path,
    shift_and_scale: bool = False,
    save_sh_dc_only: bool = True,
):
    """Write a set of 3D Gaussians to ``path`` as a PLY "vertex" element.

    Every property is stored as float32 under the names returned by
    ``construct_list_of_attributes``.

    Args:
        means: Gaussian centres, one row per Gaussian.
        scales: Per-axis scales; written as their natural logarithm.
        rotations: Quaternions as accepted by scipy's ``Rotation.from_quat``
            (scalar-last ``x, y, z, w`` by default). They are written
            scalar-first, i.e. ``rot_0..rot_3`` = ``w, x, y, z``.
        harmonics: Spherical-harmonic colour coefficients. ``[..., 0]`` is
            written to ``f_dc_*``; the remaining coefficients are flattened
            per Gaussian (all of channel 0, then channel 1, ...) into
            ``f_rest_*``.
        opacities: Written unchanged to ``opacity``. NOTE(review): no
            sigmoid/logit is applied here - confirm callers pass the value
            the consumer of the file expects.
        path: Output file; missing parent directories are created.
        shift_and_scale: If True, translate so the per-axis median centre is
            at the origin, then divide means and scales by the largest
            per-axis 95th percentile of ``|means|``.
        save_sh_dc_only: If True, omit the ``f_rest_*`` properties.

    Raises:
        ValueError: If the assembled columns do not match the PLY properties
            (e.g. ``harmonics`` has an unexpected shape).
    """
    if shift_and_scale:
        # Centre the scene on the per-axis median Gaussian position.
        means = means - means.median(dim=0).values
        # Rescale so that most Gaussians lie within [-1, 1] on every axis.
        scale_factor = means.abs().quantile(0.95, dim=0).max()
        means = means / scale_factor
        scales = scales / scale_factor

    # Round-trip through rotation matrices to obtain unit quaternions, then
    # reorder scipy's scalar-last (x, y, z, w) layout to scalar-first (w, x, y, z).
    matrices = R.from_quat(rotations.detach().cpu().numpy()).as_matrix()
    quat_xyzw = R.from_matrix(matrices).as_quat()
    quat_wxyz = quat_xyzw[:, [3, 0, 1, 2]]

    num_gaussians = means.shape[0]
    columns = [
        _to_numpy(means),
        # Normals are written as zeros; create them on the host directly.
        np.zeros(tuple(means.shape), dtype=np.float32),
        _to_numpy(harmonics[..., 0]),
    ]

    if save_sh_dc_only:
        num_rest = 0
    else:
        # Only materialise the higher-order SH bands when they are written:
        # they dominate memory for high SH degrees.
        f_rest = _to_numpy(harmonics[..., 1:].flatten(start_dim=1))
        num_rest = f_rest.shape[1]
        columns.append(f_rest)

    columns.extend(
        [
            _to_numpy(opacities[..., None]),
            _to_numpy(scales.log()),
            quat_wxyz,
        ]
    )

    attribute_names = construct_list_of_attributes(num_rest)
    table = np.concatenate(columns, axis=1)
    if table.shape[1] != len(attribute_names):
        raise ValueError(
            f"assembled {table.shape[1]} columns but the PLY layout "
            f"expects {len(attribute_names)} properties"
        )

    elements = np.empty(
        num_gaussians, dtype=[(name, "f4") for name in attribute_names]
    )
    # Fill the structured array one field at a time (vectorised, cast to
    # float32 by numpy) instead of building a Python tuple per Gaussian.
    for index, name in enumerate(attribute_names):
        elements[name] = table[:, index]

    path.parent.mkdir(exist_ok=True, parents=True)
    PlyData([PlyElement.describe(elements, "vertex")]).write(path)
|
|