# AnySplat — src/model/encodings/positional_encoding.py
import torch
import torch.nn as nn
from einops import einsum, rearrange, repeat
from jaxtyping import Float
from torch import Tensor
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for inputs in the range [0, 1].

    Each input dimension is paired with ``num_octaves`` angular frequencies,
    and each frequency with two phase offsets (0 and pi/2), so a single
    ``sin`` call produces both the sine and cosine components.
    """

    # Buffer of shape (num_octaves, 2): angular frequency per (octave, phase) slot.
    frequencies: "Float[Tensor, 'frequency phase']"
    # Buffer of shape (num_octaves, 2): phase offsets selecting sine vs. cosine.
    phases: "Float[Tensor, 'frequency phase']"

    def __init__(self, num_octaves: int):
        super().__init__()
        octave_indices = torch.arange(num_octaves).float()

        # Angular frequencies double per octave; the lowest (2*pi) has period 1.
        base_frequencies = 2 * torch.pi * 2 ** octave_indices
        frequency_table = base_frequencies.unsqueeze(-1).expand(-1, 2)
        self.register_buffer(
            "frequencies", frequency_table.contiguous(), persistent=False
        )

        # Offsetting by 0 and pi/2 makes sin(x + phase) yield (sin x, cos x).
        phase_pair = torch.tensor([0, 0.5 * torch.pi], dtype=torch.float32)
        phase_table = phase_pair.unsqueeze(0).expand(num_octaves, -1)
        self.register_buffer("phases", phase_table.contiguous(), persistent=False)

    def forward(
        self,
        samples: "Float[Tensor, '*batch dim']",
    ) -> "Float[Tensor, '*batch embedded_dim']":
        """Encode ``samples`` -> (..., dim * num_octaves * 2) features."""
        # Outer product over the trailing dim via broadcasting:
        # (..., d, 1, 1) * (f, p) -> (..., d, f, p).
        angles = samples[..., None, None] * self.frequencies
        encoded = torch.sin(angles + self.phases)
        # Collapse the (d, f, p) trailing axes into one embedding axis.
        return encoded.flatten(start_dim=-3)

    def d_out(self, dimensionality: int):
        """Size of the embedding produced for an input of ``dimensionality``."""
        # Every input dimension contributes num_octaves * 2 features.
        return self.frequencies.numel() * dimensionality