|
import torch |
|
import torch.nn as nn |
|
from einops import einsum, rearrange, repeat |
|
from jaxtyping import Float |
|
from torch import Tensor |
|
|
|
|
|
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for inputs in the range [0, 1].

    Each input coordinate is projected onto ``num_octaves`` frequency bands
    (2^0 ... 2^(num_octaves - 1) full cycles over [0, 1]) and, per band, two
    phase-shifted sinusoids: sin(x) and sin(x + pi/2) == cos(x).
    """

    # Both buffers have shape (num_octaves, 2) and are registered as
    # non-persistent so they stay out of the state dict.
    frequencies: Float[Tensor, "frequency phase"]
    phases: Float[Tensor, "frequency phase"]

    def __init__(self, num_octaves: int):
        super().__init__()
        octave_indices = torch.arange(num_octaves).float()

        # Angular frequency per octave: 2*pi*2^k, duplicated across the two
        # phase slots so it lines up elementwise with `phases`.
        freq_per_octave = 2 * torch.pi * 2**octave_indices
        self.register_buffer(
            "frequencies",
            freq_per_octave[:, None].repeat(1, 2),
            persistent=False,
        )

        # Phase offsets 0 and pi/2 produce the sin/cos pair per octave.
        phase_pair = torch.tensor([0, 0.5 * torch.pi], dtype=torch.float32)
        self.register_buffer(
            "phases",
            phase_pair[None, :].repeat(num_octaves, 1),
            persistent=False,
        )

    def forward(
        self,
        samples: Float[Tensor, "*batch dim"],
    ) -> Float[Tensor, "*batch embedded_dim"]:
        """Encode `samples`; the trailing feature size equals `d_out(dim)`."""
        # Outer product of each coordinate with every (frequency, phase)
        # slot: (*batch, dim) -> (*batch, dim, octave, 2).
        angles = samples[..., None, None] * self.frequencies
        encoded = torch.sin(angles + self.phases)
        # Collapse (dim, octave, phase) into a single trailing feature axis.
        return encoded.flatten(-3)

    def d_out(self, dimensionality: int):
        """Return the encoded feature size for a `dimensionality`-dim input."""
        return dimensionality * self.frequencies.numel()
|
|