from dataclasses import dataclass

from jaxtyping import Float
from torch import Tensor


@dataclass
class Gaussians:
    """Batched container of per-Gaussian tensors.

    Each field's expected shape is declared with a jaxtyping ``Float[...]``
    annotation. The symbolic axis names (``batch``, ``gaussian``, ``dim``,
    ``d_sh``) indicate which axes are meant to agree across fields. This class
    only stores the tensors and performs no shape or dtype validation itself.
    """

    # Gaussian centers, shape (batch, gaussian, dim).
    means: Float[Tensor, "batch gaussian dim"]
    # Covariance matrices, shape (batch, gaussian, dim, dim).
    covariances: Float[Tensor, "batch gaussian dim dim"]
    # Shape (batch, gaussian, 3, d_sh). Presumably 3 color channels times d_sh
    # spherical-harmonic coefficients; TODO confirm against the producer.
    harmonics: Float[Tensor, "batch gaussian 3 d_sh"]
    # One opacity value per Gaussian, shape (batch, gaussian).
    opacities: Float[Tensor, "batch gaussian"]
    # Scales, shape (batch, gaussian, 3).
    scales: Float[Tensor, "batch gaussian 3"]
    # Rotations, shape (batch, gaussian, 4). Presumably quaternions; confirm
    # the component order (wxyz vs xyzw) against the producer.
    rotations: Float[Tensor, "batch gaussian 4"]
    # NOTE(review): a `levels: Float[Tensor, "batch gaussian"]` field was left
    # commented out here. Either restore it as a real field or drop this note
    # once its status is decided.