# Sinusoidal timestep embedding module (e.g. for diffusion models).
import math

import torch
import torch.nn as nn
class SinusoidalEmbedding(nn.Module):
    """Sinusoidal (transformer-style) embeddings for scalar timesteps.

    Each timestep t is mapped to a vector of sin/cos values at
    log-spaced frequencies, as used for diffusion-model time
    conditioning.
    """

    def __init__(self, embed_dim: int, theta: int = 10000):
        """
        Creates sinusoidal embeddings for timesteps.

        Args:
            embed_dim: The dimensionality of the embedding. Must be a
                positive integer.
            theta: The base for the log-spaced frequencies.

        Raises:
            ValueError: If ``embed_dim`` is not a positive integer.
        """
        super().__init__()
        # Validate once at construction time (an `assert` in forward()
        # would be silently stripped under `python -O`).
        if not isinstance(embed_dim, int) or embed_dim <= 0:
            raise ValueError(f"embed_dim must be a positive integer, got {embed_dim!r}")
        self.embed_dim = embed_dim
        self.theta = theta

    def forward(self, x):
        """
        Computes sinusoidal embeddings for the input timesteps.

        Args:
            x: A 1D torch.Tensor of timesteps (shape: [batch_size]).

        Returns:
            A torch.Tensor of sinusoidal embeddings
            (shape: [batch_size, embed_dim]).
        """
        assert isinstance(x, torch.Tensor)  # Input must be a torch.Tensor
        assert x.ndim == 1  # Input must be a 1D tensor
        half_dim = self.embed_dim // 2
        # Log-spaced frequencies: exp(-log(theta) * i / (half_dim - 1)).
        # The max() guards against division by zero when embed_dim == 2
        # (half_dim == 1); for embed_dim >= 4 it is a no-op.
        scale = math.log(self.theta) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -scale)
        # Outer product: timesteps x frequencies -> [batch, half_dim]
        angles = x[:, None] * freqs[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)
        # Pad with a zero column when embed_dim is odd so the output
        # width is exactly embed_dim.
        if self.embed_dim % 2 == 1:
            embeddings = torch.cat(
                [embeddings, torch.zeros_like(embeddings[:, :1])], dim=-1
            )
        return embeddings