import torch
import torch.nn as nn
import numpy as np
|
|
class PositionalEmbedding(nn.Module):
    """
    Sinusoidal positional embedding (transformer-style cos/sin features).
    Taken from https://github.com/NVlabs/edm
    """

    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint
        # Geometric frequency schedule from 1 down to 1 / max_positions,
        # matching the EDM reference implementation (including the endpoint flag).
        freqs = torch.arange(start=0, end=self.num_channels // 2, dtype=torch.float32)
        freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
        freqs = (1 / self.max_positions) ** freqs
        self.register_buffer("freqs", freqs)
|
    def forward(self, x):
        # Outer product of the 1-D input with the frequency vector: (N,) -> (N, num_channels // 2).
        y = torch.outer(x, self.freqs.to(x.dtype))
        # Concatenate cosine and sine features: (N, num_channels), in the input's dtype.
        return torch.cat([y.cos(), y.sin()], dim=1)
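
# Example usage (illustrative sketch; channel count and inputs are arbitrary):
#   pos_emb = PositionalEmbedding(num_channels=64)
#   t = torch.linspace(0, 1, steps=8)   # e.g. a batch of timesteps / noise levels
#   pos_emb(t).shape                     # torch.Size([8, 64])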


class FourierEmbedding(nn.Module):
    """
    Random Fourier feature embedding: projects a scalar input onto fixed random frequencies.
    Taken from https://github.com/NVlabs/edm
    """

    def __init__(self, num_channels, scale=16):
        super().__init__()
        # Random Gaussian frequencies, fixed at initialization; registered as a buffer
        # so they move with the module and are saved in checkpoints, but are not trained.
        self.register_buffer("freqs", torch.randn(num_channels // 2) * scale)
|
    def forward(self, x):
        # x.ger(v) is the outer product of the 1-D input x with the frequency vector v.
        x = x.ger((2 * np.pi * self.freqs).to(x.dtype))
        # Concatenate cosine and sine features: (N, num_channels).
        x = torch.cat([x.cos(), x.sin()], dim=1)
        return x
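

# Minimal smoke test (illustrative sketch, not part of the EDM source);
# the channel count and batch size below are arbitrary.
if __name__ == "__main__":
    t = torch.rand(8)  # e.g. a batch of timesteps or noise levels
    pos_emb = PositionalEmbedding(num_channels=64)
    four_emb = FourierEmbedding(num_channels=64)
    print(pos_emb(t).shape)   # torch.Size([8, 64])
    print(four_emb(t).shape)  # torch.Size([8, 64])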
|