# Page-extraction residue, preserved as a comment: "Spaces: Running",
# file size 1,606 bytes, revision 5ab5cab.
import torch
import torch.nn as nn
import math
class SinusoidalEmbedding(nn.Module):
    """Maps scalar timesteps to fixed (non-learned) sinusoidal feature vectors.

    With ``half_dim = embed_dim // 2`` log-spaced frequencies
    ``f_k = theta ** (-k / (half_dim - 1))`` for ``k = 0 .. half_dim - 1``
    (a single frequency of 1 when ``half_dim == 1``), each timestep ``t``
    becomes ``[sin(t*f_0), ..., sin(t*f_last), cos(t*f_0), ..., cos(t*f_last)]``.
    When ``embed_dim`` is odd, one zero column is appended at the end.
    """

    def __init__(self, embed_dim: int, theta: float = 10000):
        """
        Creates sinusoidal embeddings for timesteps.

        Args:
            embed_dim: The dimensionality of the embedding; must be a positive integer.
            theta: The base for the log-spaced frequencies; must be positive.

        Raises:
            ValueError: If ``embed_dim`` is not a positive integer or ``theta``
                is not positive.
        """
        super().__init__()
        # Validate eagerly with real exceptions: ``assert`` is stripped under ``python -O``.
        if not isinstance(embed_dim, int) or embed_dim <= 0:
            raise ValueError(f"embed_dim must be a positive integer, got {embed_dim!r}")
        if theta <= 0:
            raise ValueError(f"theta must be positive, got {theta!r}")
        self.embed_dim = embed_dim
        self.theta = theta

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Computes sinusoidal embeddings for the input timesteps.

        Args:
            x: A 1D torch.Tensor of timesteps (shape: [batch_size]).

        Returns:
            A torch.Tensor of sinusoidal embeddings (shape: [batch_size, embed_dim]),
            on ``x``'s device. Its dtype follows PyTorch type promotion of ``x``
            with the floating-point frequency vector.

        Raises:
            TypeError: If ``x`` is not a torch.Tensor.
            ValueError: If ``x`` is not 1-dimensional.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError(f"x must be a torch.Tensor, got {type(x).__name__}")
        if x.ndim != 1:
            raise ValueError(f"x must be a 1D tensor, got shape {tuple(x.shape)}")

        half_dim = self.embed_dim // 2
        # Log-spaced frequencies from 1 down to 1/theta. The denominator is clamped
        # so half_dim == 1 (embed_dim 2 or 3) yields the single frequency 1 instead
        # of dividing by zero; for half_dim >= 2 the spacing is unchanged.
        scale = math.log(self.theta) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -scale)
        # Outer product: [batch_size, 1] * [1, half_dim] -> [batch_size, half_dim]
        args = x[:, None] * freqs[None, :]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)
        # Odd embed_dim: append exactly one zero column. The width is given
        # explicitly because a slice-based zeros_like would be 0 wide when
        # embed_dim == 1 (no sin/cos columns exist to slice from).
        if self.embed_dim % 2 == 1:
            pad = embeddings.new_zeros((embeddings.shape[0], 1))
            embeddings = torch.cat([embeddings, pad], dim=-1)
        return embeddings