|
import math |
|
import struct |
|
from io import BytesIO |
|
from typing import Literal, Optional |
|
|
|
import numpy as np |
|
import torch |
|
|
|
|
|
def sh2rgb(sh: torch.Tensor) -> torch.Tensor:
    """Map degree-0 spherical-harmonic coefficients to RGB values.

    Applies ``rgb = sh * C0 + 0.5`` elementwise, where ``C0`` is the constant
    degree-0 SH basis value 1 / (2 * sqrt(pi)).

    Args:
        sh (torch.Tensor): Degree-0 SH coefficients (any shape).

    Returns:
        torch.Tensor: RGB values with the same shape as ``sh`` (not clamped).
    """
    sh_c0 = 0.28209479177387814  # 1 / (2 * sqrt(pi))
    return 0.5 + sh * sh_c0
|
|
|
|
|
def part1by2_vec(x: torch.Tensor) -> torch.Tensor:
    """Spread the low 10 bits of each integer so two zero bits follow every bit.

    Bit ``i`` of the (masked) input ends up at bit ``3 * i`` of the output,
    which is the per-axis step of a 3D Morton (Z-order) encoding. Bits above
    the 10th are discarded.

    Args:
        x (torch.Tensor): Integer tensor. Shape (N,)

    Returns:
        torch.Tensor: Tensor of spread bits. Shape (N,)
    """
    spread = x & 0x000003FF
    # Each (shift, mask) pass moves groups of bits further apart until every
    # original bit sits three positions from its neighbour.
    for shift, mask in (
        (16, 0xFF0000FF),
        (8, 0x0300F00F),
        (4, 0x030C30C3),
        (2, 0x09249249),
    ):
        spread = (spread ^ (spread << shift)) & mask
    return spread
|
|
|
|
|
def encode_morton3_vec(
    x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
) -> torch.Tensor:
    """Interleave three integer coordinates into 3D Morton codes.

    The bits of ``x`` occupy positions 0, 3, 6, ...; ``y`` positions 1, 4, 7, ...;
    and ``z`` positions 2, 5, 8, .... Only the low 10 bits of each coordinate
    are used (see ``part1by2_vec``).

    Args:
        x (torch.Tensor): X coordinates. Shape (N,)
        y (torch.Tensor): Y coordinates. Shape (N,)
        z (torch.Tensor): Z coordinates. Shape (N,)

    Returns:
        torch.Tensor: Morton codes. Shape (N,)
    """
    spread_x = part1by2_vec(x)
    spread_y = part1by2_vec(y) << 1
    spread_z = part1by2_vec(z) << 2
    # The three spread values occupy disjoint bit positions, so OR-ing them
    # is the same as summing them.
    return spread_x | spread_y | spread_z
|
|
|
|
|
def sort_centers(centers: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Reorder ``indices`` so the corresponding centers follow a Morton curve.

    Each axis is normalized to the bounding box of ``centers``. It is then
    quantized to a 10-bit grid [0, 1023], and the three grid coordinates are
    interleaved into a Morton code used as the sort key.

    Args:
        centers (torch.Tensor): Centers. Shape (N, 3)
        indices (torch.Tensor): Indices to reorder. Shape (N,)

    Returns:
        torch.Tensor: ``indices`` sorted by Morton code. Shape (N,)
            Empty input is returned unchanged.
    """
    if centers.shape[0] == 0:
        # min/max below cannot reduce over an empty dimension.
        return indices

    min_vals = centers.min(dim=0).values
    max_vals = centers.max(dim=0).values

    # Guard against division by zero on flat axes.
    lengths = max_vals - min_vals
    lengths = torch.where(lengths == 0, torch.ones_like(lengths), lengths)

    # Quantize to [0, 1023]. Without the clamp, points at the maximum land on
    # 1024, which the 10-bit mask (& 0x3FF) in part1by2_vec wraps to 0.
    # int64 so the 32-bit mask constants used there (e.g. 0xFF0000FF, which
    # exceeds the int32 range) are representable.
    scaled_centers = (
        ((centers - min_vals) / lengths * 1024)
        .floor()
        .clamp(0, 1023)
        .to(torch.int64)
    )

    morton = encode_morton3_vec(
        scaled_centers[:, 0], scaled_centers[:, 1], scaled_centers[:, 2]
    )
    return indices[torch.argsort(morton).to(indices.device)]
|
|
|
|
|
def pack_unorm(value: torch.Tensor, bits: int) -> torch.Tensor:
    """Quantize values in [0, 1] to unsigned integers of ``bits`` width.

    Values are scaled by ``2**bits - 1``, rounded half-up and clamped to
    ``[0, 2**bits - 1]``, so inputs outside [0, 1] saturate.

    Args:
        value (torch.Tensor): Floating point values to quantize. Shape (N,)
        bits (int): Target bit width.

    Returns:
        torch.Tensor: Quantized values as int64. Shape (N,)
    """
    max_int = 2**bits - 1
    rounded = torch.floor(value * max_int + 0.5)
    return rounded.clamp(min=0, max=max_int).long()
|
|
|
|
|
def pack_111011(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Pack three [0, 1] values into one 32-bit word as 11/10/11 bits.

    Layout (most to least significant): ``x`` in bits 21-31, ``y`` in bits
    11-20, ``z`` in bits 0-10.

    Args:
        x (torch.Tensor): X component. Shape (N,)
        y (torch.Tensor): Y component. Shape (N,)
        z (torch.Tensor): Z component. Shape (N,)

    Returns:
        torch.Tensor: Packed values (int64 holding a 32-bit pattern). Shape (N,)
    """
    word = pack_unorm(x, 11)
    word = (word << 10) | pack_unorm(y, 10)
    word = (word << 11) | pack_unorm(z, 11)
    return word
|
|
|
|
|
def pack_8888(
    x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, w: torch.Tensor
) -> torch.Tensor:
    """Pack four [0, 1] values into one 32-bit word, 8 bits each.

    Layout (most to least significant byte): ``x``, ``y``, ``z``, ``w``.

    Args:
        x (torch.Tensor): X component. Shape (N,)
        y (torch.Tensor): Y component. Shape (N,)
        z (torch.Tensor): Z component. Shape (N,)
        w (torch.Tensor): W component. Shape (N,)

    Returns:
        torch.Tensor: Packed values (int64 holding a 32-bit pattern). Shape (N,)
    """
    word = pack_unorm(x, 8)
    for component in (y, z, w):
        word = (word << 8) | pack_unorm(component, 8)
    return word
|
|
|
|
|
def pack_rotation(q: torch.Tensor) -> torch.Tensor:
    """Pack quaternions into 32-bit words using the "smallest three" scheme.

    Each quaternion is normalized and sign-flipped so that its largest-magnitude
    component is non-negative. That component's index (0-3) is stored in bits
    30-31. The other three components lie in [-1/sqrt(2), 1/sqrt(2)] and are
    mapped to [0, 1], then stored with 10 bits each: bits 20-29, 10-19 and 0-9,
    in ascending component order.

    All-zero quaternions cannot be normalized. They are packed with index 0 and
    three zero components instead of propagating NaN into the integer cast.

    Args:
        q (torch.Tensor): Quaternions. Shape (N, 4)

    Returns:
        torch.Tensor: Packed values (int64 holding a 32-bit pattern). Shape (N,)
    """
    norms = torch.linalg.norm(q, dim=-1, keepdim=True)
    norms = torch.where(norms > 0, norms, torch.ones_like(norms))
    q = q / norms

    largest_components = torch.argmax(torch.abs(q), dim=-1)
    batch_indices = torch.arange(q.size(0), device=q.device)
    largest_values = q[batch_indices, largest_components]

    # q and -q represent the same rotation; flip so the omitted (largest)
    # component is non-negative and can be reconstructed from the other three.
    q = torch.where((largest_values < 0).unsqueeze(-1), -q, q)

    # For each possible largest index, the three remaining component indices.
    remaining_indices = torch.tensor(
        [[1, 2, 3], [0, 2, 3], [0, 1, 3], [0, 1, 2]], dtype=torch.long, device=q.device
    )
    components_to_pack = q[batch_indices[:, None], remaining_indices[largest_components]]

    # [-1/sqrt(2), 1/sqrt(2)] -> [0, 1]
    scaled = components_to_pack * (math.sqrt(2) * 0.5) + 0.5
    packed = pack_unorm(scaled, 10)

    return (
        (largest_components.to(torch.int64) << 30)
        | (packed[:, 0] << 20)
        | (packed[:, 1] << 10)
        | packed[:, 2]
    )
|
|
|
|
|
def _normalize_to_bounds(
    values: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor
) -> torch.Tensor:
    """Rescale ``values`` so ``lower`` maps to 0 and ``upper`` maps to 1.

    Columns whose range is not positive (e.g. a chunk holding a single splat)
    map to 0 instead of producing NaN from a 0/0 division.
    """
    extent = upper - lower
    safe_extent = torch.where(extent > 0, extent, torch.ones_like(extent))
    return torch.where(
        extent > 0, (values - lower) / safe_extent, torch.zeros_like(values)
    )


def splat2ply_bytes_compressed(
    means: torch.Tensor,
    scales: torch.Tensor,
    quats: torch.Tensor,
    opacities: torch.Tensor,
    sh0: torch.Tensor,
    shN: torch.Tensor,
    chunk_max_size: int = 256,
    opacity_threshold: float = 1 / 255,
) -> bytes:
    """Return the binary compressed Ply file. Used by Supersplat viewer.

    Splats whose ``sigmoid(opacity)`` does not exceed ``opacity_threshold`` are
    dropped. The rest are Morton-sorted and grouped into chunks of at most
    ``chunk_max_size``. The file contains three elements:

    - ``chunk``: 18 float32 per chunk — min/max of position, of scale (clamped
      to [-20, 20]) and of ``sh2rgb(sh0)`` color.
    - ``vertex``: 4 uint32 per splat — position and scale packed 11/10/11
      relative to the chunk bounds, rotation via ``pack_rotation``, and
      RGBA 8888 (color relative to chunk bounds, alpha = sigmoid(opacity)).
    - ``sh``: one uint8 per ``shN`` column per splat.

    Args:
        means (torch.Tensor): Splat means. Shape (N, 3)
        scales (torch.Tensor): Splat scales. Shape (N, 3)
        quats (torch.Tensor): Splat quaternions. Shape (N, 4)
        opacities (torch.Tensor): Splat opacities (pre-sigmoid). Shape (N,)
        sh0 (torch.Tensor): Spherical harmonics. Shape (N, 3)
        shN (torch.Tensor): Spherical harmonics. Shape (N, K*3)
        chunk_max_size (int): Maximum number of splats per chunk. Default: 256
        opacity_threshold (float): Opacity threshold. Default: 1 / 255

    Returns:
        bytes: Binary compressed Ply file representing the model. If no
            splat passes the opacity filter, only the header is returned.
    """
    keep = torch.sigmoid(opacities) > opacity_threshold
    means = means[keep]
    scales = scales[keep]
    sh0_colors = sh2rgb(sh0)[keep]
    shN = shN[keep]
    quats = quats[keep]
    opacities = opacities[keep]

    num_splats = means.shape[0]
    n_chunks = -(-num_splats // chunk_max_size)  # ceiling division

    float_properties = [
        "min_x",
        "min_y",
        "min_z",
        "max_x",
        "max_y",
        "max_z",
        "min_scale_x",
        "min_scale_y",
        "min_scale_z",
        "max_scale_x",
        "max_scale_y",
        "max_scale_z",
        "min_r",
        "min_g",
        "min_b",
        "max_r",
        "max_g",
        "max_b",
    ]
    uint_properties = [
        "packed_position",
        "packed_rotation",
        "packed_scale",
        "packed_color",
    ]

    buffer = BytesIO()
    buffer.write(b"ply\n")
    buffer.write(b"format binary_little_endian 1.0\n")
    buffer.write(f"element chunk {n_chunks}\n".encode())
    for prop in float_properties:
        buffer.write(f"property float {prop}\n".encode())
    buffer.write(f"element vertex {num_splats}\n".encode())
    for prop in uint_properties:
        buffer.write(f"property uint {prop}\n".encode())
    buffer.write(f"element sh {num_splats}\n".encode())
    for j in range(shN.shape[1]):
        buffer.write(f"property uchar f_rest_{j}\n".encode())
    buffer.write(b"end_header\n")

    if num_splats == 0:
        # Every element count in the header is 0. Skip the body: the Morton
        # sort needs at least one point and torch.cat rejects empty lists.
        return buffer.getvalue()

    # Morton order groups spatially nearby splats into the same chunk
    # (presumably to keep per-chunk bounds, and thus quantization error, small).
    indices = sort_centers(means, torch.arange(num_splats))

    chunk_data = []
    splat_data = []
    sh_data = []
    for chunk_idx in range(n_chunks):
        chunk_start_idx = chunk_idx * chunk_max_size
        chunk_end_idx = min(chunk_start_idx + chunk_max_size, num_splats)
        splat_idxs = indices[chunk_start_idx:chunk_end_idx]

        # Per-chunk bounds.
        chunk_means = means[splat_idxs]
        min_means = chunk_means.min(dim=0).values
        max_means = chunk_means.max(dim=0).values

        chunk_scales = scales[splat_idxs]
        min_scales = chunk_scales.min(dim=0).values.clamp(-20, 20)
        max_scales = chunk_scales.max(dim=0).values.clamp(-20, 20)

        chunk_colors = sh0_colors[splat_idxs]
        min_colors = chunk_colors.min(dim=0).values
        max_colors = chunk_colors.max(dim=0).values

        chunk_data.extend(
            [
                torch.cat([min_means, max_means]),
                torch.cat([min_scales, max_scales]),
                torch.cat([min_colors, max_colors]),
            ]
        )

        # Per-splat values quantized relative to the chunk bounds.
        norm_means = _normalize_to_bounds(chunk_means, min_means, max_means)
        means_i = pack_111011(norm_means[:, 0], norm_means[:, 1], norm_means[:, 2])

        quat_i = pack_rotation(quats[splat_idxs])

        norm_scales = _normalize_to_bounds(chunk_scales, min_scales, max_scales)
        scales_i = pack_111011(
            norm_scales[:, 0], norm_scales[:, 1], norm_scales[:, 2]
        )

        norm_colors = _normalize_to_bounds(chunk_colors, min_colors, max_colors)
        chunk_alpha = torch.sigmoid(opacities[splat_idxs])
        color_i = pack_8888(
            norm_colors[:, 0], norm_colors[:, 1], norm_colors[:, 2], chunk_alpha
        )

        # Interleave per splat in header order: position, rotation, scale, color.
        splat_data.append(
            torch.stack([means_i, quat_i, scales_i, color_i], dim=1)
            .ravel()
            .to(torch.int64)
        )

        # Higher-order SH: linear map of [-4, 4) onto 0..255, saturating outside.
        shN_quantized = torch.clamp(
            torch.trunc((shN[splat_idxs] / 8 + 0.5) * 256), 0, 255
        )
        sh_data.append(shN_quantized.to(torch.uint8).ravel())

    float_dtype = np.dtype(np.float32).newbyteorder("<")
    uint32_dtype = np.dtype(np.uint32).newbyteorder("<")
    uint8_dtype = np.dtype(np.uint8)

    buffer.write(
        torch.cat(chunk_data).detach().cpu().numpy().astype(float_dtype).tobytes()
    )
    buffer.write(
        torch.cat(splat_data).detach().cpu().numpy().astype(uint32_dtype).tobytes()
    )
    buffer.write(
        torch.cat(sh_data).detach().cpu().numpy().astype(uint8_dtype).tobytes()
    )

    return buffer.getvalue()
|
|
|
|
|
def splat2ply_bytes(
    means: torch.Tensor,
    scales: torch.Tensor,
    quats: torch.Tensor,
    opacities: torch.Tensor,
    sh0: torch.Tensor,
    shN: torch.Tensor,
) -> bytes:
    """Return the binary Ply file. Supported by almost all viewers.

    Every splat becomes one vertex of little-endian float32 properties, in the
    order: x, y, z, f_dc_*, f_rest_*, opacity, scale_*, rot_*.

    Args:
        means (torch.Tensor): Splat means. Shape (N, 3)
        scales (torch.Tensor): Splat scales. Shape (N, 3)
        quats (torch.Tensor): Splat quaternions. Shape (N, 4)
        opacities (torch.Tensor): Splat opacities. Shape (N,)
        sh0 (torch.Tensor): Spherical harmonics. Shape (N, 3)
        shN (torch.Tensor): Spherical harmonics. Shape (N, K*3)

    Returns:
        bytes: Binary Ply file representing the model.
    """
    header_lines = [
        "ply",
        "format binary_little_endian 1.0",
        f"element vertex {means.shape[0]}",
        "property float x",
        "property float y",
        "property float z",
    ]
    header_lines += [f"property float f_dc_{j}" for j in range(sh0.shape[1])]
    header_lines += [f"property float f_rest_{j}" for j in range(shN.shape[1])]
    header_lines.append("property float opacity")
    header_lines += [f"property float scale_{j}" for j in range(scales.shape[1])]
    header_lines += [f"property float rot_{j}" for j in range(quats.shape[1])]
    header_lines.append("end_header")
    header = "".join(f"{line}\n" for line in header_lines).encode()

    # Column order must match the property order declared in the header.
    columns = torch.cat(
        [means, sh0, shN, opacities.unsqueeze(1), scales, quats], dim=1
    ).to(torch.float32)
    payload = columns.detach().cpu().numpy().astype(np.dtype("<f4")).tobytes()

    return header + payload
|
|
|
|
|
def splat2splat_bytes(
    means: torch.Tensor,
    scales: torch.Tensor,
    quats: torch.Tensor,
    opacities: torch.Tensor,
    sh0: torch.Tensor,
) -> bytes:
    """Return the binary Splat file. Supported by antimatter15 viewer.

    Each splat is written as one fixed 32-byte record, in Morton order
    (``sort_centers``):

    - 3 x little-endian float32 position,
    - 3 x little-endian float32 scale (``exp`` of the input scales),
    - 4 x uint8 RGBA (RGB from ``sh2rgb(sh0)``, A from ``sigmoid(opacities)``,
      each scaled by 255 and clamped),
    - 4 x uint8 rotation (normalized quaternion components mapped by
      ``q * 128 + 128`` and clamped to [0, 255]).

    All-zero quaternions, which cannot be normalized, are written as
    (128, 128, 128, 128) rather than a NaN-derived byte.

    Args:
        means (torch.Tensor): Splat means. Shape (N, 3)
        scales (torch.Tensor): Splat scales. Shape (N, 3)
        quats (torch.Tensor): Splat quaternions. Shape (N, 4)
        opacities (torch.Tensor): Splat opacities. Shape (N,)
        sh0 (torch.Tensor): Spherical harmonics. Shape (N, 3)

    Returns:
        bytes: Binary Splat file representing the model (empty for N == 0).
    """
    num_splats = means.shape[0]
    if num_splats == 0:
        # No records to write; also skips the Morton sort, whose bounding-box
        # computation needs at least one point.
        return b""

    scales = torch.exp(scales)
    colors = torch.cat([sh2rgb(sh0), torch.sigmoid(opacities).unsqueeze(-1)], dim=1)
    colors = (colors * 255).clamp(0, 255).to(torch.uint8)

    norms = torch.linalg.norm(quats, dim=1, keepdim=True)
    norms = torch.where(norms > 0, norms, torch.ones_like(norms))
    rots = ((quats / norms) * 128 + 128).clamp(0, 255).to(torch.uint8)

    indices = sort_centers(means, torch.arange(num_splats))

    # A packed (unaligned) structured dtype reproduces the 32-byte record layout
    # exactly, so the whole file is serialized in one vectorized tobytes() call.
    record_dtype = np.dtype(
        [
            ("position", "<f4", (3,)),
            ("scale", "<f4", (3,)),
            ("color", "u1", (4,)),
            ("rotation", "u1", (4,)),
        ]
    )
    records = np.empty(num_splats, dtype=record_dtype)
    records["position"] = means[indices].detach().cpu().numpy()
    records["scale"] = scales[indices].detach().cpu().numpy()
    records["color"] = colors[indices].detach().cpu().numpy()
    records["rotation"] = rots[indices].detach().cpu().numpy()

    return records.tobytes()
|
|
|
|
|
def export_splats(
    means: torch.Tensor,
    scales: torch.Tensor,
    quats: torch.Tensor,
    opacities: torch.Tensor,
    sh0: torch.Tensor,
    shN: torch.Tensor,
    format: Literal["ply", "splat", "ply_compressed"] = "ply",
    save_to: Optional[str] = None,
) -> bytes:
    """Export a Gaussian Splats model to bytes.

    The three supported formats are:
    - ply: A standard PLY file format. Supported by most viewers.
    - splat: A custom Splat file format. Supported by antimatter15 viewer.
    - ply_compressed: A compressed PLY file format. Used by Supersplat viewer.

    Splats with any NaN or infinite value in any input are dropped before export.

    Args:
        means (torch.Tensor): Splat means. Shape (N, 3)
        scales (torch.Tensor): Splat scales. Shape (N, 3)
        quats (torch.Tensor): Splat quaternions. Shape (N, 4)
        opacities (torch.Tensor): Splat opacities. Shape (N,)
        sh0 (torch.Tensor): Spherical harmonics. Shape (N, 1, 3)
        shN (torch.Tensor): Spherical harmonics. Shape (N, K, 3)
        format (str): Export format. Options: "ply", "splat", "ply_compressed". Default: "ply"
        save_to (str): Output file path. If provided, the bytes will be written to file.

    Returns:
        bytes: The encoded model in the requested format.

    Raises:
        AssertionError: If an input tensor has an unexpected shape.
        ValueError: If ``format`` is not one of the supported options.
    """
    total_splats = means.shape[0]
    assert means.shape == (total_splats, 3), "Means must be of shape (N, 3)"
    assert scales.shape == (total_splats, 3), "Scales must be of shape (N, 3)"
    assert quats.shape == (total_splats, 4), "Quaternions must be of shape (N, 4)"
    assert opacities.shape == (total_splats,), "Opacities must be of shape (N,)"
    assert sh0.shape == (total_splats, 1, 3), "sh0 must be of shape (N, 1, 3)"
    assert (
        shN.ndim == 3 and shN.shape[0] == total_splats and shN.shape[2] == 3
    ), f"shN must be of shape (N, K, 3), got {shN.shape}"

    sh0 = sh0.squeeze(1)
    # (N, K, 3) -> (N, 3, K) -> (N, 3*K): each row holds all K coefficients of
    # channel 0, then channel 1, then channel 2. The explicit width (instead of
    # -1) keeps the reshape valid when N == 0.
    num_rest = shN.shape[1]
    shN = shN.permute(0, 2, 1).reshape(total_splats, 3 * num_rest)

    # A splat is valid only if every one of its values is finite. opacities is
    # 1-D, so it is checked elementwise; reducing it would turn a single bad
    # opacity into a scalar that invalidates every splat.
    valid_mask = (
        torch.isfinite(means).all(dim=1)
        & torch.isfinite(scales).all(dim=1)
        & torch.isfinite(quats).all(dim=1)
        & torch.isfinite(opacities)
        & torch.isfinite(sh0).all(dim=1)
        & torch.isfinite(shN).all(dim=1)
    )

    means = means[valid_mask]
    scales = scales[valid_mask]
    quats = quats[valid_mask]
    opacities = opacities[valid_mask]
    sh0 = sh0[valid_mask]
    shN = shN[valid_mask]

    if format == "ply":
        data = splat2ply_bytes(means, scales, quats, opacities, sh0, shN)
    elif format == "splat":
        data = splat2splat_bytes(means, scales, quats, opacities, sh0)
    elif format == "ply_compressed":
        data = splat2ply_bytes_compressed(means, scales, quats, opacities, sh0, shN)
    else:
        raise ValueError(f"Unsupported format: {format}")

    if save_to:
        with open(save_to, "wb") as binary_file:
            binary_file.write(data)

    return data
|
|