File size: 2,683 Bytes
2568013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
from pathlib import Path
import numpy as np
import torch
from einops import einsum, rearrange
from jaxtyping import Float
from plyfile import PlyData, PlyElement
from scipy.spatial.transform import Rotation as R
from torch import Tensor
def construct_list_of_attributes(num_rest: int) -> list[str]:
    """Return the ordered PLY vertex property names for one Gaussian.

    The order is: position (``x y z``), normal (``nx ny nz``), three
    ``f_dc_*`` entries, ``num_rest`` ``f_rest_*`` entries, ``opacity``,
    three ``scale_*`` entries and four ``rot_*`` entries.

    Args:
        num_rest: Number of ``f_rest_*`` properties to include. Zero (or a
            negative value, since ``range`` is then empty) yields none.

    Returns:
        The property names, in the order the columns must be written.
    """
    return [
        "x", "y", "z",
        "nx", "ny", "nz",
        *(f"f_dc_{i}" for i in range(3)),
        *(f"f_rest_{i}" for i in range(num_rest)),
        "opacity",
        *(f"scale_{i}" for i in range(3)),
        *(f"rot_{i}" for i in range(4)),
    ]
def export_ply(
    means: Float[Tensor, "gaussian 3"],
    scales: Float[Tensor, "gaussian 3"],
    rotations: Float[Tensor, "gaussian 4"],
    harmonics: Float[Tensor, "gaussian 3 d_sh"],
    opacities: Float[Tensor, " gaussian"],
    path: Path,
    shift_and_scale: bool = False,
    save_sh_dc_only: bool = True,
):
    """Write a set of 3D Gaussians to ``path`` as a PLY "vertex" element.

    Every property is stored as ``f4`` in the order given by
    ``construct_list_of_attributes``. Normals are written as zeros, scales
    are written as their natural log, and opacities are written exactly as
    passed in (no logit/sigmoid is applied here).

    Args:
        means: Gaussian centers.
        scales: Per-axis Gaussian scales; must be positive for ``log``.
        rotations: Quaternions in the scalar-last (x, y, z, w) order that
            ``scipy``'s ``Rotation.from_quat`` reads by default; they are
            written scalar-first as ``rot_0..rot_3`` = (w, x, y, z).
        harmonics: Spherical-harmonic coefficients; index 0 of the last
            axis is the DC band.
        opacities: One opacity value per Gaussian.
        path: Output file; missing parent directories are created.
        shift_and_scale: If True, center the scene on the per-axis median
            of ``means`` and divide ``means`` and ``scales`` by the largest
            per-axis 95th percentile of ``|means|``.
        save_sh_dc_only: If True (default), only the DC band is written;
            otherwise the remaining coefficients are written as
            ``f_rest_*``, flattened channel-first.
    """
    if shift_and_scale:
        # Shift the scene so that the median Gaussian is at the origin.
        means = means - means.median(dim=0).values

        # Rescale the scene so that most Gaussians are within range [-1, 1].
        scale_factor = means.abs().quantile(0.95, dim=0).max()
        means = means / scale_factor
        scales = scales / scale_factor

    # No extra rotation is applied here: the quaternions are only
    # round-tripped through rotation matrices and then reordered from
    # SciPy's scalar-last (x, y, z, w) layout to scalar-first (w, x, y, z).
    rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix()
    rotations = R.from_matrix(rotations).as_quat()
    x, y, z, w = rearrange(rotations, "g xyzw -> xyzw g")
    rotations = np.stack((w, x, y, z), axis=-1)

    # Columns are assembled in the same order as the property names.
    columns = [
        means.detach().cpu().numpy(),
        torch.zeros_like(means).detach().cpu().numpy(),  # nx, ny, nz
        harmonics[..., 0].detach().cpu().contiguous().numpy(),  # f_dc_*
    ]

    # The higher-order SH bands can be large, so they are only flattened and
    # copied to the CPU when they are actually going to be written.
    num_rest = 0
    if not save_sh_dc_only:
        f_rest = harmonics[..., 1:].flatten(start_dim=1)
        num_rest = f_rest.shape[1]
        columns.append(f_rest.detach().cpu().contiguous().numpy())

    columns += [
        opacities[..., None].detach().cpu().numpy(),
        scales.log().detach().cpu().numpy(),
        rotations,
    ]

    dtype_full = [
        (attribute, "f4")
        for attribute in construct_list_of_attributes(num_rest)
    ]
    elements = np.empty(means.shape[0], dtype=dtype_full)
    # A structured array is filled row by row from tuples of column values.
    elements[:] = list(map(tuple, np.concatenate(columns, axis=1)))

    path.parent.mkdir(exist_ok=True, parents=True)
    PlyData([PlyElement.describe(elements, "vertex")]).write(path)
|