import torch
import torch.nn as nn
import math

class SinusoidalEmbedding(nn.Module):
    def __init__(self, embed_dim: int, theta: int = 10000):
        """
        Creates sinusoidal embeddings for timesteps.

        Args:
            embed_dim: The dimensionality of the embedding.
            theta: The base for the log-spaced frequencies.
        """
        super().__init__()
        # Validate at construction time so misconfiguration fails fast
        assert isinstance(embed_dim, int) and embed_dim > 0, "embed_dim must be a positive integer"
        self.embed_dim = embed_dim
        self.theta = theta

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Computes sinusoidal embeddings for the input timesteps.

        Args:
            x: A 1D torch.Tensor of timesteps (shape: [batch_size]).

        Returns:
            A torch.Tensor of sinusoidal embeddings (shape: [batch_size, embed_dim]).
        """
        assert isinstance(x, torch.Tensor), "Input must be a torch.Tensor"
        assert x.ndim == 1, "Input must be a 1D tensor of timesteps"

        half_dim = self.embed_dim // 2
        # Log-spaced frequencies: theta^(-i / (half_dim - 1)) for i in [0, half_dim).
        # max(half_dim - 1, 1) guards against division by zero when embed_dim < 4.
        exponent = math.log(self.theta) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -exponent)
        # Outer product of timesteps and frequencies: [batch_size, half_dim]
        angles = x[:, None] * freqs[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)
        # For odd embed_dim, pad one zero column so the output width equals embed_dim
        if self.embed_dim % 2 == 1:
            embeddings = torch.cat([embeddings, torch.zeros_like(embeddings[:, :1])], dim=-1)
        return embeddings
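
if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module):
    # embed a batch of integer timesteps and check the output shapes.
    emb = SinusoidalEmbedding(embed_dim=128)
    t = torch.arange(16)            # timesteps 0..15, shape: [16]
    print(emb(t).shape)             # torch.Size([16, 128])

    # Odd embedding dimensions are padded with a final zero column.
    print(SinusoidalEmbedding(embed_dim=33)(t).shape)  # torch.Size([16, 33])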