# AnySplat: src/post_opt/exporter.py
# Uploaded by alexnasa ("Upload 243 files", commit 2568013).
import math
import struct
from io import BytesIO
from typing import Literal, Optional
import numpy as np
import torch
def sh2rgb(sh: torch.Tensor) -> torch.Tensor:
    """Map degree-0 spherical-harmonic (DC) coefficients to RGB.

    Computes ``sh * C0 + 0.5`` where ``C0`` is the zeroth-order SH basis
    constant ``1 / (2 * sqrt(pi))``. The result is not clamped.

    Args:
        sh (torch.Tensor): SH DC coefficients, any shape.

    Returns:
        torch.Tensor: Colour values with the same shape as ``sh``.
    """
    sh_c0 = 0.28209479177387814  # 1 / (2 * sqrt(pi))
    scaled = sh * sh_c0
    return scaled + 0.5
def part1by2_vec(x: torch.Tensor) -> torch.Tensor:
    """Spread the low 10 bits of each element so that bit ``i`` lands on bit ``3 * i``.

    Bits above bit 9 are discarded first. The two zero bits left between each
    pair of input bits make room for the other two axes when building a 3D
    Morton code.

    Args:
        x (torch.Tensor): Integer tensor. Shape (N,)

    Returns:
        torch.Tensor: Bit-spread values. Shape (N,)
    """
    # (shift, mask) pairs of the classic "magic numbers" bit expansion.
    expansion_steps = (
        (16, 0xFF0000FF),
        (8, 0x0300F00F),
        (4, 0x030C30C3),
        (2, 0x09249249),
    )
    spread = x & 0x000003FF
    for shift, mask in expansion_steps:
        spread = (spread ^ (spread << shift)) & mask
    return spread
def encode_morton3_vec(
    x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
) -> torch.Tensor:
    """Interleave three 10-bit coordinates into 30-bit Morton (Z-order) codes.

    Bit ``i`` of ``x``, ``y`` and ``z`` end up at bits ``3i``, ``3i + 1`` and
    ``3i + 2`` of the code respectively.

    Args:
        x (torch.Tensor): Integer X coordinates. Shape (N,)
        y (torch.Tensor): Integer Y coordinates. Shape (N,)
        z (torch.Tensor): Integer Z coordinates. Shape (N,)

    Returns:
        torch.Tensor: Morton codes. Shape (N,)
    """
    spread_x = part1by2_vec(x)
    spread_y = part1by2_vec(y) << 1
    spread_z = part1by2_vec(z) << 2
    return spread_z + spread_y + spread_x
def sort_centers(centers: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Reorder ``indices`` along a 3D Morton (Z-order) curve through ``centers``.

    Each center is quantised onto a 1024^3 grid spanning the bounding box of
    all centers. The grid cells are interleaved into Morton codes, and
    ``indices`` is permuted by ascending code so spatially close splats end up
    adjacent.

    Args:
        centers (torch.Tensor): Centers. Shape (N, 3)
        indices (torch.Tensor): Indices to permute. Shape (N,)

    Returns:
        torch.Tensor: ``indices`` reordered by Morton code. Shape (N,)
    """
    if centers.shape[0] == 0:
        # min/max over an empty dimension raise; there is nothing to sort.
        return indices

    min_vals = centers.min(dim=0).values
    max_vals = centers.max(dim=0).values

    # An axis on which every center has the same coordinate would give 0 / 0.
    lengths = max_vals - min_vals
    lengths = torch.where(lengths == 0, torch.ones_like(lengths), lengths)

    # Quantise into the 10-bit range [0, 1023]. Without the clamp, centers on
    # the max face map to 1024, which part1by2_vec masks down to 0, so they
    # would sort next to the min face.
    scaled = ((centers - min_vals) / lengths * 1024).floor().clamp(0, 1023)
    scaled = scaled.to(torch.int64)

    morton = encode_morton3_vec(scaled[:, 0], scaled[:, 1], scaled[:, 2])
    order = torch.argsort(morton).to(indices.device)
    return indices[order]
def pack_unorm(value: torch.Tensor, bits: int) -> torch.Tensor:
    """Quantise values in [0, 1] to unsigned integers of ``bits`` bits.

    Values are scaled by ``2**bits - 1`` and rounded half-up. The results are
    clamped into ``[0, 2**bits - 1]``, so inputs outside [0, 1] saturate.

    Args:
        value (torch.Tensor): Floating point values to quantise. Shape (N,)
        bits (int): Bit width of each quantised value.

    Returns:
        torch.Tensor: Quantised values as int64. Shape (N,)
    """
    max_code = (1 << bits) - 1
    rounded = torch.floor(value * max_code + 0.5)
    return rounded.clamp(min=0, max=max_code).to(torch.int64)
def pack_111011(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Pack three [0, 1] values into one 32-bit word using an 11/10/11-bit layout.

    The word is laid out as ``x`` in bits 21-31, ``y`` in bits 11-20 and ``z``
    in bits 0-10.

    Args:
        x (torch.Tensor): X component. Shape (N,)
        y (torch.Tensor): Y component. Shape (N,)
        z (torch.Tensor): Z component. Shape (N,)

    Returns:
        torch.Tensor: Packed words as int64. Shape (N,)
    """
    x_bits, y_bits, z_bits = pack_unorm(x, 11), pack_unorm(y, 10), pack_unorm(z, 11)
    return (x_bits << 21) | (y_bits << 11) | z_bits
def pack_8888(
    x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, w: torch.Tensor
) -> torch.Tensor:
    """Pack four [0, 1] values into one 32-bit word, 8 bits each.

    ``x`` occupies the most significant byte and ``w`` the least significant.

    Args:
        x (torch.Tensor): X component. Shape (N,)
        y (torch.Tensor): Y component. Shape (N,)
        z (torch.Tensor): Z component. Shape (N,)
        w (torch.Tensor): W component. Shape (N,)

    Returns:
        torch.Tensor: Packed words as int64. Shape (N,)
    """
    word = pack_unorm(x, 8) << 24
    for shift, component in ((16, y), (8, z), (0, w)):
        word = word | (pack_unorm(component, 8) << shift)
    return word
def pack_rotation(q: torch.Tensor) -> torch.Tensor:
    """Pack quaternions into 32-bit words using the "smallest three" scheme.

    Each quaternion is normalised and sign-flipped so that its
    largest-magnitude component is non-negative. The index of that component
    goes in the top 2 bits. The other three components, in their original
    order, are mapped from [-1/sqrt(2), 1/sqrt(2)] to [0, 1] and stored as
    10 bits each (bits 20-29, 10-19, 0-9). The input tensor is not modified.

    Args:
        q (torch.Tensor): Quaternions. Shape (N, 4)

    Returns:
        torch.Tensor: Packed words as int64. Shape (N,)
    """
    unit = q / torch.linalg.norm(q, dim=-1, keepdim=True)

    # Index of the largest-magnitude component; it is dropped and later
    # reconstructed from the other three by the reader.
    largest = torch.argmax(torch.abs(unit), dim=-1)
    largest_value = unit.gather(-1, largest.unsqueeze(-1))  # (N, 1)
    unit = torch.where(largest_value < 0, -unit, unit)

    # Row k lists the component indices kept when component k is dropped.
    kept_table = torch.tensor(
        [[1, 2, 3], [0, 2, 3], [0, 1, 3], [0, 1, 2]],
        dtype=torch.long,
        device=unit.device,
    )
    kept = unit.gather(-1, kept_table[largest])  # (N, 3)

    half_sqrt2 = math.sqrt(2) * 0.5
    codes = pack_unorm(kept * half_sqrt2 + 0.5, 10)

    return (
        (largest.to(torch.int64) << 30)
        | (codes[:, 0] << 20)
        | (codes[:, 1] << 10)
        | codes[:, 2]
    )
def _normalize_to_unit(
    values: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor
) -> torch.Tensor:
    """Linearly map ``values`` from ``[lower, upper]`` to ``[0, 1]`` per column.

    Columns with a zero range (e.g. a single-splat chunk, or all splats in a
    chunk sharing a value) map to 0 instead of the NaN that ``0 / 0`` would
    produce. A NaN would reach ``pack_unorm`` and be cast to a
    platform-dependent integer, corrupting the neighbouring packed bit fields.
    """
    span = upper - lower
    has_range = span > 0
    safe_span = torch.where(has_range, span, torch.ones_like(span))
    return torch.where(
        has_range, (values - lower) / safe_span, torch.zeros_like(values)
    )


def _compressed_ply_header(n_chunks: int, num_splats: int, num_sh: int) -> bytes:
    """Build the ASCII header of the compressed PLY layout.

    The ``sh`` element is declared only when there are higher-order SH
    coefficients to store. An element without properties carries no data.
    """
    chunk_properties = [
        "min_x",
        "min_y",
        "min_z",
        "max_x",
        "max_y",
        "max_z",
        "min_scale_x",
        "min_scale_y",
        "min_scale_z",
        "max_scale_x",
        "max_scale_y",
        "max_scale_z",
        "min_r",
        "min_g",
        "min_b",
        "max_r",
        "max_g",
        "max_b",
    ]
    vertex_properties = [
        "packed_position",
        "packed_rotation",
        "packed_scale",
        "packed_color",
    ]
    lines = ["ply", "format binary_little_endian 1.0", f"element chunk {n_chunks}"]
    lines += [f"property float {prop}" for prop in chunk_properties]
    lines.append(f"element vertex {num_splats}")
    lines += [f"property uint {prop}" for prop in vertex_properties]
    if num_sh > 0:
        lines.append(f"element sh {num_splats}")
        lines += [f"property uchar f_rest_{j}" for j in range(num_sh)]
    lines.append("end_header")
    return "".join(f"{line}\n" for line in lines).encode()


def splat2ply_bytes_compressed(
    means: torch.Tensor,
    scales: torch.Tensor,
    quats: torch.Tensor,
    opacities: torch.Tensor,
    sh0: torch.Tensor,
    shN: torch.Tensor,
    chunk_max_size: int = 256,
    opacity_threshold: float = 1 / 255,
) -> bytes:
    """Return the binary compressed Ply file. Used by Supersplat viewer.

    Splats with ``sigmoid(opacity) <= opacity_threshold`` are dropped. The
    rest are Morton-sorted and grouped into chunks of at most
    ``chunk_max_size`` splats. The file then contains, in order:

    * per chunk: 18 float32 bounds (min/max of position, of scale clamped to
      [-20, 20], and of the SH-DC colour);
    * per splat: 4 uint32 words (11/10/11-bit position and scale normalised to
      the chunk bounds, a "smallest three" rotation, and 8-bit RGBA with
      A = sigmoid(opacity));
    * per splat, when ``shN`` has columns: one uint8 per SH rest coefficient,
      quantised as ``trunc((v / 8 + 0.5) * 256)`` clamped to [0, 255].

    Args:
        means (torch.Tensor): Splat means. Shape (N, 3)
        scales (torch.Tensor): Splat scales. Shape (N, 3)
        quats (torch.Tensor): Splat quaternions. Shape (N, 4)
        opacities (torch.Tensor): Splat opacity logits (``sigmoid`` is applied). Shape (N,)
        sh0 (torch.Tensor): Spherical harmonics. Shape (N, 3)
        shN (torch.Tensor): Spherical harmonics. Shape (N, K*3)
        chunk_max_size (int): Maximum number of splats per chunk. Default: 256
        opacity_threshold (float): Opacity threshold. Default: 1 / 255

    Returns:
        bytes: Binary compressed Ply file representing the model.

    Raises:
        ValueError: If ``chunk_max_size`` is not positive.
    """
    if chunk_max_size <= 0:
        raise ValueError(f"chunk_max_size must be positive, got {chunk_max_size}")

    # Filter the splats with too low opacity.
    mask = torch.sigmoid(opacities) > opacity_threshold
    means = means[mask]
    scales = scales[mask]
    sh0_colors = sh2rgb(sh0)[mask]
    shN = shN[mask]
    quats = quats[mask]
    opacities = opacities[mask]

    num_splats = means.shape[0]
    num_sh = shN.shape[1]
    n_chunks = (num_splats + chunk_max_size - 1) // chunk_max_size

    header = _compressed_ply_header(n_chunks, num_splats, num_sh)
    if num_splats == 0:
        # Everything was filtered out: a header-only file is still valid,
        # whereas torch.cat on the empty lists below would raise.
        return header

    order = sort_centers(means, torch.arange(num_splats, device=means.device))

    chunk_data = []
    splat_data = []
    sh_data = []
    for start in range(0, num_splats, chunk_max_size):
        idx = order[start : start + chunk_max_size]

        # Per-chunk bounds.
        chunk_means = means[idx]
        min_means = chunk_means.min(dim=0).values
        max_means = chunk_means.max(dim=0).values

        chunk_scales = scales[idx]
        min_scales = chunk_scales.min(dim=0).values.clamp(-20, 20)
        max_scales = chunk_scales.max(dim=0).values.clamp(-20, 20)

        chunk_colors = sh0_colors[idx]
        min_colors = chunk_colors.min(dim=0).values
        max_colors = chunk_colors.max(dim=0).values

        chunk_data.extend(
            [min_means, max_means, min_scales, max_scales, min_colors, max_colors]
        )

        # Per-splat quantised properties, normalised to the chunk bounds.
        norm_means = _normalize_to_unit(chunk_means, min_means, max_means)
        means_i = pack_111011(norm_means[:, 0], norm_means[:, 1], norm_means[:, 2])

        quat_i = pack_rotation(quats[idx])

        norm_scales = _normalize_to_unit(chunk_scales, min_scales, max_scales)
        scales_i = pack_111011(
            norm_scales[:, 0], norm_scales[:, 1], norm_scales[:, 2]
        )

        norm_colors = _normalize_to_unit(chunk_colors, min_colors, max_colors)
        alpha = torch.sigmoid(opacities[idx])
        color_i = pack_8888(
            norm_colors[:, 0], norm_colors[:, 1], norm_colors[:, 2], alpha
        )

        packed = torch.stack([means_i, quat_i, scales_i, color_i], dim=1)
        splat_data.append(packed.ravel().to(torch.int64))

        # Quantised higher-order spherical harmonics.
        if num_sh > 0:
            sh_q = torch.trunc((shN[idx] / 8 + 0.5) * 256).clamp(0, 255)
            sh_data.append(sh_q.to(torch.uint8).ravel())

    float_dtype = np.dtype(np.float32).newbyteorder("<")
    uint32_dtype = np.dtype(np.uint32).newbyteorder("<")

    buffer = BytesIO()
    buffer.write(header)
    buffer.write(
        torch.cat(chunk_data).detach().cpu().numpy().astype(float_dtype).tobytes()
    )
    buffer.write(
        torch.cat(splat_data).detach().cpu().numpy().astype(uint32_dtype).tobytes()
    )
    if sh_data:
        buffer.write(
            torch.cat(sh_data).detach().cpu().numpy().astype(np.uint8).tobytes()
        )
    return buffer.getvalue()
def splat2ply_bytes(
    means: torch.Tensor,
    scales: torch.Tensor,
    quats: torch.Tensor,
    opacities: torch.Tensor,
    sh0: torch.Tensor,
    shN: torch.Tensor,
) -> bytes:
    """Return the binary Ply file. Supported by almost all viewers.

    Every splat becomes one row of little-endian float32 values, in the order
    x, y, z, f_dc_*, f_rest_*, opacity, scale_*, rot_*. Values are written
    as given, with no activation applied.

    Args:
        means (torch.Tensor): Splat means. Shape (N, 3)
        scales (torch.Tensor): Splat scales. Shape (N, 3)
        quats (torch.Tensor): Splat quaternions. Shape (N, 4)
        opacities (torch.Tensor): Splat opacities. Shape (N,)
        sh0 (torch.Tensor): Spherical harmonics. Shape (N, 3)
        shN (torch.Tensor): Spherical harmonics. Shape (N, K*3)

    Returns:
        bytes: Binary Ply file representing the model.
    """
    property_names = ["x", "y", "z"]
    property_names += [f"f_dc_{j}" for j in range(sh0.shape[1])]
    property_names += [f"f_rest_{j}" for j in range(shN.shape[1])]
    property_names.append("opacity")
    property_names += [f"scale_{j}" for j in range(scales.shape[1])]
    property_names += [f"rot_{j}" for j in range(quats.shape[1])]

    header_lines = [
        "ply",
        "format binary_little_endian 1.0",
        f"element vertex {means.shape[0]}",
    ]
    header_lines += [f"property float {name}" for name in property_names]
    header_lines.append("end_header")
    header = "".join(f"{line}\n" for line in header_lines).encode()

    # Column order must match the property order declared above.
    columns = [means, sh0, shN, opacities.unsqueeze(1), scales, quats]
    rows = torch.cat(columns, dim=1).to(torch.float32).detach().cpu().numpy()
    return header + rows.astype(np.dtype("<f4")).tobytes()
def splat2splat_bytes(
    means: torch.Tensor,
    scales: torch.Tensor,
    quats: torch.Tensor,
    opacities: torch.Tensor,
    sh0: torch.Tensor,
) -> bytes:
    """Return the binary Splat file. Supported by antimatter15 viewer.

    Splats are written in Morton order of their means, each as a packed
    32-byte record:

    * position: 3 x little-endian float32 (``means`` as given);
    * scale: 3 x little-endian float32 (``exp(scales)``);
    * colour: 4 x uint8 RGBA, RGB = ``sh2rgb(sh0)``, A = ``sigmoid(opacities)``,
      each scaled by 255, clamped to [0, 255] and truncated;
    * rotation: 4 x uint8, normalised quaternion mapped by ``q * 128 + 128``
      and clamped to [0, 255].

    Args:
        means (torch.Tensor): Splat means. Shape (N, 3)
        scales (torch.Tensor): Splat scales (``exp`` is applied). Shape (N, 3)
        quats (torch.Tensor): Splat quaternions. Shape (N, 4)
        opacities (torch.Tensor): Splat opacity logits (``sigmoid`` is applied). Shape (N,)
        sh0 (torch.Tensor): Spherical harmonics. Shape (N, 3)

    Returns:
        bytes: Binary Splat file representing the model.
    """
    num_splats = means.shape[0]
    if num_splats == 0:
        # sort_centers cannot reduce over an empty set; an empty file is valid.
        return b""

    # Activate and quantise.
    scales = torch.exp(scales)
    colors = torch.cat([sh2rgb(sh0), torch.sigmoid(opacities).unsqueeze(-1)], dim=1)
    colors = (colors * 255).clamp(0, 255).to(torch.uint8)
    rots = quats / torch.linalg.norm(quats, dim=1, keepdim=True) * 128 + 128
    rots = rots.clamp(0, 255).to(torch.uint8)

    order = sort_centers(means, torch.arange(num_splats))

    # One packed (unaligned) record per splat, serialised in a single call
    # instead of four Python-level writes per splat.
    record = np.dtype(
        [
            ("position", "<f4", (3,)),
            ("scale", "<f4", (3,)),
            ("color", "u1", (4,)),
            ("rotation", "u1", (4,)),
        ]
    )
    records = np.empty(num_splats, dtype=record)
    records["position"] = means[order].detach().cpu().numpy()
    records["scale"] = scales[order].detach().cpu().numpy()
    records["color"] = colors[order].detach().cpu().numpy()
    records["rotation"] = rots[order].detach().cpu().numpy()
    return records.tobytes()
def export_splats(
    means: torch.Tensor,
    scales: torch.Tensor,
    quats: torch.Tensor,
    opacities: torch.Tensor,
    sh0: torch.Tensor,
    shN: torch.Tensor,
    format: Literal["ply", "splat", "ply_compressed"] = "ply",
    save_to: Optional[str] = None,
) -> bytes:
    """Export a Gaussian Splats model to bytes.

    The three supported formats are:
    - ply: A standard PLY file format. Supported by most viewers.
    - splat: A custom Splat file format. Supported by antimatter15 viewer.
    - ply_compressed: A compressed PLY file format. Used by Supersplat viewer.

    Splats with a NaN or infinite value in any of their attributes are
    dropped before encoding.

    Args:
        means (torch.Tensor): Splat means. Shape (N, 3)
        scales (torch.Tensor): Splat scales. Shape (N, 3)
        quats (torch.Tensor): Splat quaternions. Shape (N, 4)
        opacities (torch.Tensor): Splat opacities. Shape (N,)
        sh0 (torch.Tensor): Spherical harmonics. Shape (N, 1, 3)
        shN (torch.Tensor): Spherical harmonics. Shape (N, K, 3)
        format (str): Export format. Options: "ply", "splat", "ply_compressed". Default: "ply"
        save_to (str): Output file path. If provided, the bytes will be written to file.

    Returns:
        bytes: The encoded file contents (also written to ``save_to`` if given).

    Raises:
        AssertionError: If the input shapes are inconsistent.
        ValueError: If ``format`` is not supported.
    """
    # Reject an unknown format before doing any work.
    if format not in ("ply", "splat", "ply_compressed"):
        raise ValueError(f"Unsupported format: {format}")

    total_splats = means.shape[0]
    assert means.shape == (total_splats, 3), "Means must be of shape (N, 3)"
    assert scales.shape == (total_splats, 3), "Scales must be of shape (N, 3)"
    assert quats.shape == (total_splats, 4), "Quaternions must be of shape (N, 4)"
    assert opacities.shape == (total_splats,), "Opacities must be of shape (N,)"
    assert sh0.shape == (total_splats, 1, 3), "sh0 must be of shape (N, 1, 3)"
    assert (
        shN.ndim == 3 and shN.shape[0] == total_splats and shN.shape[2] == 3
    ), f"shN must be of shape (N, K, 3), got {shN.shape}"

    # Flatten SH to channel-major columns. The explicit column count (instead
    # of -1) keeps this valid when N == 0.
    sh0 = sh0.squeeze(1)  # (N, 3)
    num_rest = shN.shape[1] * shN.shape[2]
    shN = shN.permute(0, 2, 1).reshape(total_splats, num_rest)  # (N, K * 3)

    # Keep only splats whose every attribute is finite. Opacities are (N,),
    # so they are checked element-wise. Reducing them with .any(dim=0) would
    # give a scalar that discards every splat as soon as one is non-finite.
    valid_mask = (
        torch.isfinite(means).all(dim=1)
        & torch.isfinite(scales).all(dim=1)
        & torch.isfinite(quats).all(dim=1)
        & torch.isfinite(opacities)
        & torch.isfinite(sh0).all(dim=1)
        & torch.isfinite(shN).all(dim=1)
    )
    means = means[valid_mask]
    scales = scales[valid_mask]
    quats = quats[valid_mask]
    opacities = opacities[valid_mask]
    sh0 = sh0[valid_mask]
    shN = shN[valid_mask]

    if format == "ply":
        data = splat2ply_bytes(means, scales, quats, opacities, sh0, shN)
    elif format == "splat":
        data = splat2splat_bytes(means, scales, quats, opacities, sh0)
    else:  # "ply_compressed", guaranteed by the check above
        data = splat2ply_bytes_compressed(means, scales, quats, opacities, sh0, shN)

    if save_to:
        with open(save_to, "wb") as binary_file:
            binary_file.write(data)
    return data