KoFace-AI / diffusion_model /network /timestep_embedding.py
JuyeopDang's picture
Upload 35 files
5ab5cab verified
import torch
import torch.nn as nn
import math
class SinusoidalEmbedding(nn.Module):
    def __init__(self, embed_dim: int, theta: int = 10000):
        """
        Creates sinusoidal (transformer-style) embeddings for timesteps.

        Args:
            embed_dim: The dimensionality of the embedding. Must be a
                positive integer.
            theta: The base for the log-spaced frequencies (default 10000).

        Raises:
            ValueError: If embed_dim is not a positive integer.
        """
        super().__init__()
        # Validate eagerly at construction time. The original checked this in
        # forward() via `assert`, which is silently stripped under `python -O`.
        if not isinstance(embed_dim, int) or embed_dim <= 0:
            raise ValueError(f"embed_dim must be a positive integer, got {embed_dim!r}")
        self.embed_dim = embed_dim
        self.theta = theta

    def forward(self, x):
        """
        Computes sinusoidal embeddings for the input timesteps.

        Args:
            x: A 1D torch.Tensor of timesteps (shape: [batch_size]).

        Returns:
            A torch.Tensor of sinusoidal embeddings
            (shape: [batch_size, embed_dim]).

        Raises:
            TypeError: If x is not a torch.Tensor.
            ValueError: If x is not a 1D tensor.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError(f"x must be a torch.Tensor, got {type(x).__name__}")
        if x.ndim != 1:
            raise ValueError(f"x must be a 1D tensor, got {x.ndim} dimensions")
        half_dim = self.embed_dim // 2
        # Log-spaced frequencies: exp(-log(theta) * i / (half_dim - 1)) for
        # i in [0, half_dim). The max(..., 1) guards the denominator so that
        # half_dim == 1 (embed_dim of 2 or 3) no longer raises
        # ZeroDivisionError; larger embed_dims are unaffected.
        scale = math.log(self.theta) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -scale)
        # Outer product: [batch, 1] * [1, half_dim] -> [batch, half_dim]
        angles = x[:, None] * freqs[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)
        # For odd embed_dim, sin+cos yields only embed_dim - 1 columns;
        # zero-pad one column to reach the requested width. Building the pad
        # with new_zeros (instead of slicing embeddings[:, :1]) also makes
        # embed_dim == 1 return shape [batch, 1] rather than [batch, 0].
        if self.embed_dim % 2 == 1:
            pad = embeddings.new_zeros(embeddings.size(0), 1)
            embeddings = torch.cat([embeddings, pad], dim=-1)
        return embeddings