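# NOTE: the next three lines are page-header residue from the file's web view, kept as comments so the module parses.
# mrbear1024's picture
# init project
# 8eb4303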
import torch
from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor
from torch.nn import Module
import torch.nn.functional as F
from beartype import beartype
from beartype.typing import Tuple, Optional, List, Union
from einops.layers.torch import Rearrange
from einops import rearrange, repeat, reduce, pack, unpack
# from gateloop_transformer import SimpleGateLoopLayer as GateLoop
from modules.audio2motion.cfm.utils import *
from modules.audio2motion.cfm.attend import Attend
import math
from functools import partial
from torch.cuda.amp import autocast
# sinusoidal positions
class LearnedSinusoidalPosEmb(Module):
    """Sin/cos feature embedding with learned frequencies (used by @crowsonkb).

    A vector of ``dim // 2`` learnable weights is multiplied against each input
    scalar; the sine and cosine of the resulting angles are concatenated.
    """

    def __init__(self, dim):
        """
        Args:
            dim: requested feature size; must satisfy ``divisible_by(dim, 2)``
                (helper imported from the project utils - presumably an
                evenness check, TODO confirm).
        """
        super().__init__()
        assert divisible_by(dim, 2)
        # one learnable frequency per sin/cos pair
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x):
        """
        Args:
            x: 1-D tensor of shape ``(b,)``.

        Returns:
            Tensor of shape ``(b, 2 * (dim // 2))``: sines first, then cosines.
        """
        column = rearrange(x, 'b -> b 1')
        row = rearrange(self.weights, 'd -> 1 d')
        # broadcast (b, 1) against (1, d) to get one angle per (input, frequency)
        angles = column * row * 2 * math.pi
        return torch.cat((angles.sin(), angles.cos()), dim = -1)
# rotary positional embeddings
# https://arxiv.org/abs/2104.09864
class RotaryEmbedding(Module):
    """Produces the rotation angles used by rotary positional embeddings.

    Stores ``inv_freq[i] = 1 / theta ** (2i / dim)`` as a buffer and, given
    positions, returns the outer product of positions and frequencies
    duplicated along the last axis.
    """

    def __init__(self, dim, theta = 50000):
        """
        Args:
            dim: size of the dimension being rotated; frequencies are built
                from ``torch.arange(0, dim, 2)``.
            theta: base of the geometric frequency progression.
        """
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (theta ** exponents))

    @property
    def device(self):
        """Device on which the frequency buffer currently lives."""
        return self.inv_freq.device

    @autocast(enabled = False)
    @beartype
    def forward(self, t: Union[int, Tensor]):
        """
        Args:
            t: either a count ``n`` (positions ``0..n-1`` are generated on this
                module's device) or a 1-D tensor of explicit positions.

        Returns:
            Tensor of shape ``(len(positions), 2 * len(inv_freq))``.
        """
        positions = t if torch.is_tensor(t) else torch.arange(t, device = self.device)
        # match the buffer's dtype before taking the outer product
        positions = positions.type_as(self.inv_freq)
        angles = torch.einsum('i , j -> i j', positions, self.inv_freq)
        return torch.cat((angles, angles), dim = -1)
def rotate_half(x):
    """Split the last axis into two chunks ``(a, b)`` and return ``(-b, a)``.

    Args:
        x: tensor whose last dimension is split with ``chunk(2, dim = -1)``.

    Returns:
        Tensor formed by concatenating the negated second chunk with the first
        chunk along the last axis.
    """
    first, second = x.chunk(2, dim = -1)
    rotated = (second.neg(), first)
    return torch.cat(rotated, dim = -1)
@autocast(enabled = False)
def apply_rotary_pos_emb(pos, t):
    """Rotate ``t`` by the angles in ``pos``.

    Computes ``t * cos(pos) + rotate_half(t) * sin(pos)``. The decorator
    disables CUDA autocast for the duration of the call.

    Args:
        pos: tensor of rotation angles, broadcastable against ``t``.
        t: tensor to rotate.

    Returns:
        The rotated tensor.
    """
    cos_term = t * pos.cos()
    sin_term = rotate_half(t) * pos.sin()
    return cos_term + sin_term
# convolutional positional generating module
class ConvPositionEmbed(Module):
    """Positional signal produced by a 1-D convolution followed by GELU."""

    def __init__(self, dim, *, kernel_size, groups = None):
        """
        Args:
            dim: channel count (input and output of the convolution).
            kernel_size: convolution width; must satisfy ``is_odd`` (project
                helper). Padding of ``kernel_size // 2`` is applied.
            groups: convolution groups; falls back to ``dim`` via the project
                ``default`` helper, i.e. a full depthwise conv by default.
        """
        super().__init__()
        assert is_odd(kernel_size)

        # full depthwise conv by default
        groups = default(groups, dim)

        conv = nn.Conv1d(
            dim,
            dim,
            kernel_size,
            groups = groups,
            padding = kernel_size // 2
        )
        self.dw_conv1d = nn.Sequential(conv, nn.GELU())

    def forward(self, x):
        """
        Args:
            x: tensor laid out as ``(b, n, c)``.

        Returns:
            Tensor laid out as ``(b, n, c)`` after conv + GELU.
        """
        # Conv1d wants channels before length
        channels_first = rearrange(x, 'b n c -> b c n')
        convolved = self.dw_conv1d(channels_first)
        return rearrange(convolved, 'b c n -> b n c')
# norms
class RMSNorm(Module):
    """Normalizes the last axis to unit L2 norm, then rescales.

    The output is ``normalize(x) * sqrt(dim) * gamma`` where ``gamma`` is a
    learned per-feature gain initialised to ones.
    """

    def __init__(self, dim):
        """
        Args:
            dim: size of the last (feature) axis.
        """
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``x`` normalized over its last axis and scaled by ``scale * gamma``."""
        unit = F.normalize(x, dim = -1)
        return unit * self.scale * self.gamma
class AdaptiveRMSNorm(Module):
    """RMS-style norm whose gain and shift are predicted from a conditioning vector.

    The output is ``normalize(x) * sqrt(dim) * gamma(cond) + beta(cond)``. The
    two projections are initialised so that ``gamma == 1`` and ``beta == 0``
    for any ``cond``.
    """

    def __init__(self, dim, cond_dim = None):
        """
        Args:
            dim: size of the feature axis being normalized.
            cond_dim: size of the conditioning vector; falls back to ``dim``
                via the project ``default`` helper.
        """
        super().__init__()
        cond_dim = default(cond_dim, dim)

        self.scale = dim ** 0.5
        self.to_gamma = nn.Linear(cond_dim, dim)
        self.to_beta = nn.Linear(cond_dim, dim)

        # init to identity: zero weights, bias 1 for the gain and 0 for the shift
        for proj, init_bias in ((self.to_gamma, nn.init.ones_), (self.to_beta, nn.init.zeros_)):
            nn.init.zeros_(proj.weight)
            init_bias(proj.bias)

    def forward(self, x, *, cond):
        """
        Args:
            x: tensor normalized over its last axis; broadcast against a
                ``(b, 1, d)`` gain/shift.
            cond: conditioning tensor whose projection has shape ``(b, d)``.

        Returns:
            The normalized tensor, scaled and shifted per batch element.
        """
        normed = F.normalize(x, dim = -1) * self.scale

        gamma = self.to_gamma(cond)
        beta = self.to_beta(cond)

        # add a length-1 sequence axis so the modulation broadcasts over positions
        gamma = rearrange(gamma, 'b d -> b 1 d')
        beta = rearrange(beta, 'b d -> b 1 d')

        return normed * gamma + beta
# attention
class MultiheadRMSNorm(Module):
    """RMS-style norm with a separate learned gain per attention head.

    ``gamma`` has shape ``(heads, 1, dim)`` and is initialised to ones.
    """

    def __init__(self, dim, heads):
        """
        Args:
            dim: size of the last (per-head feature) axis.
            heads: number of heads, i.e. leading size of ``gamma``.
        """
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))

    def forward(self, x):
        """Normalize ``x`` over its last axis, then apply the per-head gain and ``scale``."""
        unit = F.normalize(x, dim = -1)
        return unit * self.gamma * self.scale
class Attention(Module):
    """Multi-head self-attention with optional query/key normalization and rotary positions.

    The attention computation itself is delegated to the project's ``Attend``
    module.
    """

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0,
        flash = False,
        qk_norm = False,
        qk_norm_scale = 10
    ):
        """
        Args:
            dim: input and output feature size.
            dim_head: feature size per head.
            heads: number of heads.
            dropout: forwarded to ``Attend`` as its first argument.
            flash: forwarded to ``Attend``.
            qk_norm: when true, queries and keys pass through a
                ``MultiheadRMSNorm`` each and ``qk_norm_scale`` is handed to
                ``Attend`` as ``scale``.
            qk_norm_scale: scale given to ``Attend`` when ``qk_norm`` is set;
                otherwise ``Attend`` receives ``scale = None``.
        """
        super().__init__()
        self.heads = heads
        dim_inner = dim_head * heads

        if qk_norm:
            scale = qk_norm_scale
        else:
            scale = None

        self.attend = Attend(dropout, flash = flash, scale = scale)

        self.qk_norm = qk_norm

        if qk_norm:
            self.q_norm = MultiheadRMSNorm(dim_head, heads = heads)
            self.k_norm = MultiheadRMSNorm(dim_head, heads = heads)

        # a single projection yields queries, keys and values side by side
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
        self.to_out = nn.Linear(dim_inner, dim, bias = False)

    def forward(self, x, mask = None, rotary_emb = None):
        """
        Args:
            x: tensor laid out as ``(b, n, dim)``.
            mask: passed through unchanged to ``Attend``.
            rotary_emb: optional rotation angles applied to queries and keys.

        Returns:
            Tensor laid out as ``(b, n, dim)``.
        """
        projected = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in projected)

        if self.qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        if exists(rotary_emb):
            q = apply_rotary_pos_emb(rotary_emb, q)
            k = apply_rotary_pos_emb(rotary_emb, k)

        attended = self.attend(q, k, v, mask = mask)

        # merge the heads back into a single feature axis
        merged = rearrange(attended, 'b h n d -> b n (h d)')
        return self.to_out(merged)
# feedforward
class GEGLU(Module):
    """GELU-gated linear unit: halves the last axis and gates one half with the other."""

    def forward(self, x):
        """
        Args:
            x: tensor whose last axis is split in two with ``chunk``.

        Returns:
            ``gelu(second_half) * first_half``.
        """
        value, gate = x.chunk(2, dim = -1)
        activated = F.gelu(gate)
        return activated * value
def FeedForward(dim, mult = 4, dropout = 0.):
    """Build the gated feedforward stack: Linear -> GEGLU -> Dropout -> Linear.

    Args:
        dim: input and output feature size.
        mult: expansion factor; the hidden width is ``int(dim * mult * 2 / 3)``.
        dropout: dropout probability applied after the gate.

    Returns:
        An ``nn.Sequential`` mapping ``(..., dim)`` to ``(..., dim)``.
    """
    dim_inner = int(dim * mult * 2 / 3)

    stages = [
        # doubled width because GEGLU consumes half of it as the gate
        nn.Linear(dim, dim_inner * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim_inner, dim),
    ]
    return nn.Sequential(*stages)
# transformer
class Transformer(Module):
    """Pre-norm transformer stack with rotary positions.

    Each layer applies attention and a feedforward block, both with residual
    connections. Optional features: learned register tokens prepended to the
    sequence, adaptive (conditioned) norms, and U-Net style skip connections
    from the first half of the layers into the second half.
    """

    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        attn_dropout = 0.,
        ff_dropout = 0.,
        num_register_tokens = 0,
        attn_flash = False,
        adaptive_rmsnorm = False,
        adaptive_rmsnorm_cond_dim_in = None,
        use_unet_skip_connection = False,
        skip_connect_scale = None,
        attn_qk_norm = False,
        use_gateloop_layers = False
    ):
        """
        Args:
            dim: model feature size.
            depth: number of layers; must satisfy ``divisible_by(depth, 2)``.
            dim_head: per-head feature size (also the rotary embedding size).
            heads: number of attention heads.
            ff_mult: feedforward expansion factor.
            attn_dropout: dropout passed to each ``Attention``.
            ff_dropout: dropout passed to each ``FeedForward``.
            num_register_tokens: number of learned tokens prepended to the
                sequence. It is a count used as a tensor size and pad width,
                so the default is the integer ``0``.
            attn_flash: forwarded to each ``Attention`` as ``flash``.
            adaptive_rmsnorm: use ``AdaptiveRMSNorm`` for the per-layer
                pre-norms instead of ``RMSNorm``.
            adaptive_rmsnorm_cond_dim_in: ``cond_dim`` for the adaptive norms.
            use_unet_skip_connection: give every layer in the second half a
                ``Linear(dim * 2, dim)`` that merges in a stored activation.
            skip_connect_scale: multiplier for the stored activation; defaults
                to ``2 ** -0.5``.
            attn_qk_norm: forwarded to each ``Attention`` as ``qk_norm``.
            use_gateloop_layers: accepted for interface compatibility but
                currently ignored - the gateloop slot is always ``None``.
        """
        super().__init__()
        assert divisible_by(depth, 2)
        self.layers = nn.ModuleList([])

        self.rotary_emb = RotaryEmbedding(dim = dim_head)

        self.num_register_tokens = num_register_tokens
        self.has_register_tokens = num_register_tokens > 0

        if self.has_register_tokens:
            self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))

        if adaptive_rmsnorm:
            rmsnorm_klass = partial(AdaptiveRMSNorm, cond_dim = adaptive_rmsnorm_cond_dim_in)
        else:
            rmsnorm_klass = RMSNorm

        self.skip_connect_scale = default(skip_connect_scale, 2 ** -0.5)

        for ind in range(depth):
            layer = ind + 1
            # only the second half of the stack consumes skip connections
            has_skip = use_unet_skip_connection and layer > (depth // 2)

            self.layers.append(nn.ModuleList([
                nn.Linear(dim * 2, dim) if has_skip else None,
                # placeholder for the disabled gateloop layer (GateLoop import is commented out)
                None,
                rmsnorm_klass(dim = dim),
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = attn_flash, qk_norm = attn_qk_norm),
                rmsnorm_klass(dim = dim),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            ]))

        self.final_norm = RMSNorm(dim)

    @property
    def device(self):
        """Device of the first parameter of this module."""
        return next(self.parameters()).device

    def forward(
        self,
        x,
        mask = None,
        adaptive_rmsnorm_cond = None
    ):
        """
        Args:
            x: tensor whose first two axes are batch and sequence length and
                whose last axis is ``dim``.
            mask: optional mask forwarded to every attention layer; when
                register tokens are used it is left-padded with ``True``.
            adaptive_rmsnorm_cond: conditioning passed as ``cond`` to the
                per-layer pre-norms when given.

        Returns:
            The final-normed sequence with any register tokens removed.
        """
        batch, seq_len, *_ = x.shape

        # add register tokens to the left

        if self.has_register_tokens:
            register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)

            x, ps = pack([register_tokens, x], 'b * d')

            if exists(mask):
                mask = F.pad(mask, (self.num_register_tokens, 0), value = True)

        # keep track of skip connections

        skip_connects = []

        # rotary embeddings

        positions = seq_len

        if self.has_register_tokens:
            # register tokens all share one far-away position (-10000)
            main_positions = torch.arange(seq_len, device = self.device, dtype = torch.long)
            register_positions = torch.full((self.num_register_tokens,), -10000, device = self.device, dtype = torch.long)
            positions = torch.cat((register_positions, main_positions))

        rotary_emb = self.rotary_emb(positions)

        # adaptive rmsnorm

        rmsnorm_kwargs = dict()
        if exists(adaptive_rmsnorm_cond):
            rmsnorm_kwargs = dict(cond = adaptive_rmsnorm_cond)

        # going through the attention layers

        for skip_combiner, maybe_gateloop, attn_prenorm, attn, ff_prenorm, ff in self.layers:

            # in the paper, they use a u-net like skip connection
            # unclear how much this helps, as no ablations or further numbers given besides a brief one-two sentence mention

            if not exists(skip_combiner):
                skip_connects.append(x)
            else:
                # most recently stored activation pairs with the current layer (stack order)
                skip_connect = skip_connects.pop() * self.skip_connect_scale
                x = torch.cat((x, skip_connect), dim = -1)
                x = skip_combiner(x)

            if exists(maybe_gateloop):
                x = maybe_gateloop(x) + x

            attn_input = attn_prenorm(x, **rmsnorm_kwargs)
            x = attn(attn_input, mask = mask, rotary_emb = rotary_emb) + x

            ff_input = ff_prenorm(x, **rmsnorm_kwargs)
            x = ff(ff_input) + x

        # remove the register tokens

        if self.has_register_tokens:
            _, x = unpack(x, ps, 'b * d')

        return self.final_norm(x)
if __name__ == '__main__':
    # Smoke test: build a small stack and push one random batch through it.
    model = Transformer(dim=512, depth=6, dim_head=64, heads=8, ff_mult=4)

    # input laid out as (batch_size, sequence_length, input_dim)
    dummy_input = torch.randn(1, 10, 512)

    result = model(dummy_input)

    # show the shape of the result
    print(result.shape)