Spaces:
Build error
Build error
import torch | |
from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor | |
from torch.nn import Module | |
import torch.nn.functional as F | |
from beartype import beartype | |
from beartype.typing import Tuple, Optional, List, Union | |
from einops.layers.torch import Rearrange | |
from einops import rearrange, repeat, reduce, pack, unpack | |
# from gateloop_transformer import SimpleGateLoopLayer as GateLoop | |
from modules.audio2motion.cfm.utils import * | |
from modules.audio2motion.cfm.attend import Attend | |
import math | |
from functools import partial | |
from torch.cuda.amp import autocast | |
# sinusoidal positions
class LearnedSinusoidalPosEmb(Module):
    """Fourier-feature embedding with learned frequencies (used by @crowsonkb).

    Holds ``dim // 2`` learned frequencies, randomly initialised from a
    standard normal. ``forward`` maps a 1-D batch of scalars to the
    concatenation of the sine and cosine of ``x * weights * 2 * pi``, so the
    result has shape ``(batch, dim)``.
    """

    def __init__(self, dim):
        super().__init__()
        # sin and cos each take half of the output width
        assert divisible_by(dim, 2)
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x):
        # (b,) -> (b, 1) and (d,) -> (1, d) so the product broadcasts to (b, d)
        column = rearrange(x, 'b -> b 1')
        learned = rearrange(self.weights, 'd -> 1 d')
        angles = column * learned * 2 * math.pi
        return torch.cat((angles.sin(), angles.cos()), dim = -1)
# rotary positional embeddings
# https://arxiv.org/abs/2104.09864
class RotaryEmbedding(Module):
    """Produce rotary position angles (https://arxiv.org/abs/2104.09864).

    Args:
        dim: width of the rotated feature dimension; frequencies are built
            from ``torch.arange(0, dim, 2)``.
        theta: base of the geometric frequency progression.

    The inverse frequencies are stored in the non-trainable buffer
    ``inv_freq``.
    """

    def __init__(self, dim, theta = 50000):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def device(self):
        """Return the device holding ``inv_freq``.

        This is a regular method, not a property: it has to be called.
        """
        return self.inv_freq.device

    def forward(self, t: Union[int, Tensor]):
        """Return the angle table for the given positions.

        Args:
            t: either a sequence length (positions ``0 .. t-1`` are generated)
                or a 1-D tensor of explicit positions.

        Returns:
            Tensor of shape ``(len(positions), 2 * len(inv_freq))``: the outer
            product of positions and inverse frequencies, repeated twice
            along the last dimension.
        """
        if not torch.is_tensor(t):
            # Read the device off the buffer. `self.device` is a bound method
            # here, and handing a method to `torch.arange(device=...)` raises
            # a TypeError.
            t = torch.arange(t, device = self.inv_freq.device)

        t = t.type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        return torch.cat((freqs, freqs), dim = -1)
def rotate_half(x):
    """Split the last dimension in two halves ``(a, b)`` and return ``(-b, a)``."""
    first, second = torch.chunk(x, 2, dim = -1)
    return torch.cat((second.neg(), first), dim = -1)
def apply_rotary_pos_emb(pos, t):
    """Rotate ``t`` by the angles ``pos``: ``t * cos(pos) + rotate_half(t) * sin(pos)``."""
    cos_term = t * pos.cos()
    sin_term = rotate_half(t) * pos.sin()
    return cos_term + sin_term
# convolutional positional generating module
class ConvPositionEmbed(Module):
    """Positional signal produced by a 1-D convolution followed by GELU.

    Args:
        dim: channel count of both the input and the output.
        kernel_size: convolution width (keyword-only); asserted via
            ``is_odd``. Padding is ``kernel_size // 2``.
        groups: convolution groups (keyword-only); falls back to ``dim``
            through ``default`` when not given.
    """

    def __init__(
        self,
        dim,
        *,
        kernel_size,
        groups = None
    ):
        super().__init__()
        assert is_odd(kernel_size)

        # full depthwise conv unless the caller asks for fewer groups
        groups = default(groups, dim)

        conv = nn.Conv1d(
            dim,
            dim,
            kernel_size,
            groups = groups,
            padding = kernel_size // 2
        )
        self.dw_conv1d = nn.Sequential(conv, nn.GELU())

    def forward(self, x):
        # Conv1d wants channels before length, so swap on the way in and out
        channels_first = rearrange(x, 'b n c -> b c n')
        convolved = self.dw_conv1d(channels_first)
        return rearrange(convolved, 'b c n -> b n c')
# norms
class RMSNorm(Module):
    """L2-normalise the last dimension, then rescale.

    The output is ``F.normalize(x, dim=-1) * sqrt(dim) * gamma`` where
    ``gamma`` is a learned per-feature gain initialised to ones.
    """

    def __init__(
        self,
        dim
    ):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        # constant factor applied after the unit-norm projection
        self.scale = dim ** 0.5

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        return unit * self.scale * self.gamma
class AdaptiveRMSNorm(Module):
    """RMS-style norm whose gain and shift are predicted from a condition.

    Args:
        dim: feature width of the tensor being normalised.
        cond_dim: width of the conditioning vector; falls back to ``dim``
            through ``default`` when not given.

    Both projections start with zero weights; the gain bias starts at one and
    the shift bias at zero, so the freshly built module applies a gain of one
    and a shift of zero regardless of the condition.
    """

    def __init__(
        self,
        dim,
        cond_dim = None
    ):
        super().__init__()
        cond_dim = default(cond_dim, dim)
        self.scale = dim ** 0.5

        self.to_gamma = nn.Linear(cond_dim, dim)
        self.to_beta = nn.Linear(cond_dim, dim)

        # init to identity: gamma == 1, beta == 0
        for projection in (self.to_gamma, self.to_beta):
            nn.init.zeros_(projection.weight)
        nn.init.ones_(self.to_gamma.bias)
        nn.init.zeros_(self.to_beta.bias)

    def forward(self, x, *, cond):
        unit_scaled = F.normalize(x, dim = -1) * self.scale

        # one (b, d) gain / shift per batch element, broadcast over position
        gamma = rearrange(self.to_gamma(cond), 'b d -> b 1 d')
        beta = rearrange(self.to_beta(cond), 'b d -> b 1 d')

        return unit_scaled * gamma + beta
# attention
class MultiheadRMSNorm(Module):
    """RMS-style norm with a separate learned gain per attention head.

    ``gamma`` has shape ``(heads, 1, dim)`` and is initialised to ones; the
    output is ``F.normalize(x, dim=-1) * gamma * sqrt(dim)``.
    """

    def __init__(self, dim, heads):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))
        self.scale = dim ** 0.5

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        return unit * self.gamma * self.scale
class Attention(Module):
    """Multi-head self-attention with optional query/key norm and rotary angles.

    Args:
        dim: input and output feature width.
        dim_head: width of each head.
        heads: number of heads.
        dropout: forwarded to ``Attend`` as its first argument.
        flash: forwarded to ``Attend``.
        qk_norm: when true, queries and keys each pass through a
            ``MultiheadRMSNorm`` and ``qk_norm_scale`` is handed to ``Attend``
            as ``scale``; otherwise ``scale`` is ``None``.
        qk_norm_scale: see ``qk_norm``.
    """

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0,
        flash = False,
        qk_norm = False,
        qk_norm_scale = 10
    ):
        super().__init__()
        self.heads = heads
        inner_width = dim_head * heads

        attend_scale = None
        if qk_norm:
            attend_scale = qk_norm_scale
        self.attend = Attend(dropout, flash = flash, scale = attend_scale)

        self.qk_norm = qk_norm
        if qk_norm:
            self.q_norm = MultiheadRMSNorm(dim_head, heads = heads)
            self.k_norm = MultiheadRMSNorm(dim_head, heads = heads)

        # a single bias-free projection yields queries, keys and values at once
        self.to_qkv = nn.Linear(dim, inner_width * 3, bias = False)
        self.to_out = nn.Linear(inner_width, dim, bias = False)

    def forward(self, x, mask = None, rotary_emb = None):
        """Attend over ``x``; ``mask`` is passed through to ``Attend`` unchanged."""
        projected = self.to_qkv(x).chunk(3, dim = -1)

        # split the packed feature axis into (heads, dim_head)
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
            for part in projected
        )

        if self.qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)

        if exists(rotary_emb):
            q = apply_rotary_pos_emb(rotary_emb, q)
            k = apply_rotary_pos_emb(rotary_emb, k)

        attended = self.attend(q, k, v, mask = mask)

        merged = rearrange(attended, 'b h n d -> b n (h d)')
        return self.to_out(merged)
# feedforward
class GEGLU(Module):
    """GELU-gated linear unit: halve the last dimension and gate one half by the other."""

    def forward(self, x):
        # first half carries the values, second half drives the gate
        values, gate = torch.chunk(x, 2, dim = -1)
        return F.gelu(gate) * values
def FeedForward(dim, mult = 4, dropout = 0.):
    """Build the gated feedforward block: Linear -> GEGLU -> Dropout -> Linear.

    The hidden width is ``int(dim * mult * 2 / 3)``. The first projection is
    twice that wide because ``GEGLU`` halves the last dimension.
    """
    hidden = int(dim * mult * 2 / 3)

    stages = [
        nn.Linear(dim, hidden * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, dim),
    ]
    return nn.Sequential(*stages)
# transformer
class Transformer(Module):
    """Pre-norm transformer stack with rotary positions.

    Optional extras are learned register tokens, conditioned (adaptive)
    RMSNorm, and U-Net style skip connections between the first and second
    half of the layers.

    Args:
        dim: model width.
        depth: number of layers; asserted via ``divisible_by(depth, 2)``.
        dim_head: width of each attention head (also the rotary dimension).
        heads: number of attention heads.
        ff_mult: feedforward expansion factor.
        attn_dropout: dropout handed to each ``Attention``.
        ff_dropout: dropout handed to each ``FeedForward``.
        num_register_tokens: number of learned tokens prepended to the
            sequence; pass an ``int`` when enabling them, since the value is
            used as a tensor size.
        attn_flash: forwarded to ``Attention`` as ``flash``.
        adaptive_rmsnorm: use ``AdaptiveRMSNorm`` for the per-layer pre-norms.
        adaptive_rmsnorm_cond_dim_in: condition width for ``AdaptiveRMSNorm``.
        use_unet_skip_connection: give each layer in the second half a linear
            combiner that merges in the matching first-half activation.
        skip_connect_scale: factor applied to a skip activation before it is
            merged; defaults to ``2 ** -0.5``.
        attn_qk_norm: forwarded to ``Attention`` as ``qk_norm``.
        use_gateloop_layers: accepted for interface compatibility only; the
            GateLoop layer is disabled in this file, so the slot is always
            ``None``.
    """

    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        attn_dropout = 0.,
        ff_dropout = 0.,
        num_register_tokens = 0.,
        attn_flash = False,
        adaptive_rmsnorm = False,
        adaptive_rmsnorm_cond_dim_in = None,
        use_unet_skip_connection = False,
        skip_connect_scale = None,
        attn_qk_norm = False,
        use_gateloop_layers = False
    ):
        super().__init__()
        assert divisible_by(depth, 2)
        self.layers = nn.ModuleList([])

        self.rotary_emb = RotaryEmbedding(dim = dim_head)

        self.num_register_tokens = num_register_tokens
        self.has_register_tokens = num_register_tokens > 0

        if self.has_register_tokens:
            self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))

        if adaptive_rmsnorm:
            rmsnorm_klass = partial(AdaptiveRMSNorm, cond_dim = adaptive_rmsnorm_cond_dim_in)
        else:
            rmsnorm_klass = RMSNorm

        self.skip_connect_scale = default(skip_connect_scale, 2 ** -0.5)

        for ind in range(depth):
            layer = ind + 1
            # only the second half of the stack consumes a skip connection
            has_skip = use_unet_skip_connection and layer > (depth // 2)

            self.layers.append(nn.ModuleList([
                nn.Linear(dim * 2, dim) if has_skip else None,
                # GateLoop(dim = dim) if use_gateloop_layers else None,
                None,
                rmsnorm_klass(dim = dim),
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = attn_flash, qk_norm = attn_qk_norm),
                rmsnorm_klass(dim = dim),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            ]))

        self.final_norm = RMSNorm(dim)

    def device(self):
        """Return the device of the first parameter.

        This is a regular method, not a property: it has to be called.
        """
        return next(self.parameters()).device

    def forward(
        self,
        x,
        mask = None,
        adaptive_rmsnorm_cond = None
    ):
        """Run the stack over ``x``.

        Args:
            x: input whose first two dimensions are batch and sequence length.
            mask: optional mask handed to every attention layer; when register
                tokens are used it is left-padded with ``True`` along its last
                dimension (presumably a boolean ``(batch, seq)`` mask; confirm
                against ``Attend``).
            adaptive_rmsnorm_cond: condition passed as ``cond`` to every
                per-layer pre-norm when given.

        Returns:
            The final-normed activations with register tokens removed.
        """
        batch, seq_len, *_ = x.shape

        # add register tokens to the left

        if self.has_register_tokens:
            register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)

            x, ps = pack([register_tokens, x], 'b * d')

            if exists(mask):
                mask = F.pad(mask, (self.num_register_tokens, 0), value = True)

        # keep track of skip connections

        skip_connects = []

        # rotary embeddings
        #
        # Positions are always built as an explicit tensor on the input's
        # device. `self.device` is a bound method in this class, so passing it
        # as `device=` raises a TypeError, and handing a plain integer length
        # to the rotary module would make it resolve a device on its own.

        positions = torch.arange(seq_len, device = x.device, dtype = torch.long)

        if self.has_register_tokens:
            # register tokens share one far-away position, apart from the real sequence
            register_positions = torch.full((self.num_register_tokens,), -10000, device = x.device, dtype = torch.long)
            positions = torch.cat((register_positions, positions))

        rotary_emb = self.rotary_emb(positions)

        # adaptive rmsnorm

        rmsnorm_kwargs = dict()
        if exists(adaptive_rmsnorm_cond):
            rmsnorm_kwargs = dict(cond = adaptive_rmsnorm_cond)

        # going through the attention layers

        for skip_combiner, maybe_gateloop, attn_prenorm, attn, ff_prenorm, ff in self.layers:

            # in the paper, they use a u-net like skip connection
            # unclear how much this helps, as no ablations or further numbers given besides a brief one-two sentence mention

            if not exists(skip_combiner):
                skip_connects.append(x)
            else:
                # most recently stored activation pairs with the current layer
                skip_connect = skip_connects.pop() * self.skip_connect_scale
                x = torch.cat((x, skip_connect), dim = -1)
                x = skip_combiner(x)

            if exists(maybe_gateloop):
                x = maybe_gateloop(x) + x

            attn_input = attn_prenorm(x, **rmsnorm_kwargs)
            x = attn(attn_input, mask = mask, rotary_emb = rotary_emb) + x

            ff_input = ff_prenorm(x, **rmsnorm_kwargs)
            x = ff(ff_input) + x

        # remove the register tokens

        if self.has_register_tokens:
            _, x = unpack(x, ps, 'b * d')

        return self.final_norm(x)
if __name__ == '__main__':
    # Smoke test: push one random batch through a small stack and report
    # the shape that comes out.
    model = Transformer(dim = 512, depth = 6, dim_head = 64, heads = 8, ff_mult = 4)

    # (batch_size, sequence_length, input_dim)
    sample = torch.randn(1, 10, 512)

    print(model(sample).shape)