|
|
|
import warnings
import logging
from itertools import chain
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn, Tensor, einsum
import torch.nn.functional as F
from einops import rearrange
|
|
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
|
dtype = torch.float32 |
|
warnings.filterwarnings("ignore") |
|
logging.basicConfig(level=logging.ERROR) |
|
|
|
def scaled_relu(x, sequence_length): |
|
relu_output = torch.relu(x) |
|
return relu_output / sequence_length |
|
|
|
def taylor_softmax(x, order=2):
    # Approximate softmax by replacing exp(x) with its truncated Taylor series
    # 1 + x + x^2/2! + ... + x^order/order!, then normalising over the last dim.
    tapprox = torch.ones_like(x)
    for i in range(1, order + 1):
        factorial_i = torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32, device=x.device)))
        tapprox = tapprox + x**i / factorial_i
    return tapprox / torch.sum(tapprox, dim=-1, keepdim=True)
|
|
|
def there_is_a(a): |
|
return a is not None |
|
|
|
def AorB(a, b): |
|
return a if there_is_a(a) else b |
|
|
|
def sinusoids(ctx, dims, max_tscale=10000): |
|
assert dims % 2 == 0 |
|
pos = torch.log(torch.tensor(float(max_tscale))) / (dims // 2 - 1) |
|
tscales = torch.exp(-pos * torch.arange(dims // 2, device=device, dtype=torch.float32)) |
|
scaled = torch.arange(ctx, device=device, dtype=torch.float32).unsqueeze(1) * tscales.unsqueeze(0) |
|
position = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=1) |
|
positional_embedding = nn.Parameter(position, requires_grad=True) |
|
return positional_embedding |
|
|
|
def get_activation(act: str) -> nn.Module: |
|
    act_map = {
        "gelu": nn.GELU(),
        "relu": nn.ReLU(),
        "sigmoid": nn.Sigmoid(),
        "tanh": nn.Tanh(),
        "swish": nn.SiLU(),
        "silu": nn.SiLU(),
        "tanhshrink": nn.Tanhshrink(),
        "softplus": nn.Softplus(),
        "softshrink": nn.Softshrink(),
        "leaky_relu": nn.LeakyReLU(),
        "elu": nn.ELU(),
    }
    return act_map.get(act, nn.GELU())
|
|
|
@dataclass |
|
class Dimensions: |
|
tokens: int |
|
mels: int |
|
ctx: int |
|
dims: int |
|
head: int |
|
head_dim: int |
|
layer: int |
|
act: str |
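
# Example (hypothetical values): the attention modules derive their per-head size
# as dims // head, so head_dim is expected to satisfy head_dim == dims // head,
# e.g. Dimensions(tokens=1000, mels=80, ctx=256, dims=64, head=4, head_dim=16,
#                 layer=2, act="gelu").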
|
|
|
def vectorized_taylor_sine(x, order=5): |
|
original_shape = x.shape |
|
x = x.flatten(0, -2) |
|
exponents = torch.arange(1, order + 1, 2, device=x.device, dtype=torch.float32) |
|
x_powers = x.unsqueeze(-1) ** exponents |
|
factorials = torch.exp(torch.lgamma(exponents + 1)) |
|
signs = (-1)**(torch.arange(0, len(exponents), device=x.device, dtype=torch.float32)) |
|
terms = signs * x_powers / factorials |
|
result = terms.sum(dim=-1) |
|
return result.view(original_shape) |
|
|
|
def vectorized_taylor_cosine(x, order=5): |
|
original_shape = x.shape |
|
x = x.flatten(0, -2) |
|
exponents = torch.arange(0, order + 1, 2, device=x.device, dtype=torch.float32) |
|
x_powers = x.unsqueeze(-1) ** exponents |
|
factorials = torch.exp(torch.lgamma(exponents + 1)) |
|
signs = (-1)**(torch.arange(0, len(exponents), device=x.device, dtype=torch.float32)) |
|
terms = signs * x_powers / factorials |
|
result = terms.sum(dim=-1) |
|
return result.view(original_shape) |
|
|
|
class rotary(nn.Module): |
|
def __init__(self, dims, head): |
|
super(rotary, self).__init__() |
|
self.dims = dims |
|
self.head = head |
|
self.head_dim = dims // head |
|
self.taylor_order = 10 |
|
|
|
self.theta = nn.Parameter((torch.tensor(360000, device=device, dtype=dtype)), requires_grad=False) |
|
self.register_buffer('freqs_base', self._compute_freqs_base(), persistent=False) |
|
|
|
def _compute_freqs_base(self): |
|
mel_scale = torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 4000/200)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1 |
|
return 200 * mel_scale / 1000 |
|
|
|
def forward(self, x) -> torch.Tensor: |
|
positions = (torch.arange(0, x.shape[2], device=x.device)) |
|
freqs = (self.theta / 220.0) * self.freqs_base |
|
freqs = positions[:, None] * freqs |
|
freqs_rescaled = (freqs + torch.pi) % (2 * torch.pi) - torch.pi |
|
|
|
with torch.autocast(device_type="cuda", enabled=False): |
|
cos = vectorized_taylor_cosine(freqs_rescaled, order=self.taylor_order) |
|
sin = vectorized_taylor_sine(freqs_rescaled, order=self.taylor_order) |
|
rotary_dim = cos.shape[-1] |
|
x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:] |
|
x_embed = (x_rot * cos) + (rotate_half(x_rot) * sin) |
|
x_embed = torch.cat([x_embed, x_pass], dim=-1) |
|
return x_embed.type_as(x) |
|
|
|
def taylor_sine(x, order=5): |
|
result = torch.zeros_like(x) |
|
for i in range(order + 1): |
|
if i % 2 == 1: |
|
term = x**i / torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32))) |
|
if (i // 2) % 2 == 1: |
|
result -= term |
|
else: |
|
result += term |
|
return result |
|
|
|
def taylor_cosine(x, order=5): |
|
result = torch.zeros_like(x) |
|
for i in range(order + 1): |
|
if i % 2 == 0: |
|
term = x**i / torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32))) |
|
if (i // 2) % 2 == 1: |
|
result -= term |
|
else: |
|
result += term |
|
return result |
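
# Illustrative sanity check (not used by the model): the truncated Taylor series
# above are only accurate near zero, which is why rotary.forward wraps its angles
# into [-pi, pi) before calling them. The helper below compares the approximations
# against torch's exact functions on small arbitrary inputs; the orders and tensor
# values are assumptions chosen for demonstration.
def _taylor_sanity_check():
    angles = torch.linspace(-3.0, 3.0, steps=9).reshape(3, 3)
    sin_err = (vectorized_taylor_sine(angles, order=9) - torch.sin(angles)).abs().max()
    cos_err = (vectorized_taylor_cosine(angles, order=10) - torch.cos(angles)).abs().max()
    logits = torch.tensor([[0.1, 0.2, -0.3]])
    soft_err = (taylor_softmax(logits, order=2) - torch.softmax(logits, dim=-1)).abs().max()
    print(f"sin err {sin_err:.1e} | cos err {cos_err:.1e} | softmax err {soft_err:.1e}")
# _taylor_sanity_check()  # uncomment for a quick numerical comparison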
|
|
|
class rotarya(nn.Module): |
|
def __init__(self, dims, head): |
|
        super().__init__()
|
self.dims = dims |
|
self.head = head |
|
self.head_dim = dims // head |
|
self.taylor_order = 5 |
|
|
|
self.theta = nn.Parameter((torch.tensor(1600, device=device, dtype=dtype)), requires_grad=False) |
|
self.register_buffer('freqs_base', self._compute_freqs_base(), persistent=False) |
|
|
|
def _compute_freqs_base(self): |
|
mel_scale = torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 4000/200)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1 |
|
return 200 * mel_scale / 1000 |
|
|
|
def forward(self, x) -> torch.Tensor: |
|
|
|
positions = (torch.arange(0, x.shape[2], device=x.device)) |
|
freqs = (self.theta / 220.0) * self.freqs_base |
|
freqs = positions[:, None] * freqs |
|
freqs = (freqs + torch.pi) % (2 * torch.pi) - torch.pi |
|
with torch.autocast(device_type="cuda", enabled=False): |
|
cos = taylor_cosine(freqs, order=self.taylor_order) |
|
sin = taylor_sine(freqs, order=self.taylor_order) |
|
rotary_dim = cos.shape[-1] |
|
x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:] |
|
x_embed = (x_rot * cos) + (rotate_half(x_rot) * sin) |
|
x_embed = torch.cat([x_embed, x_pass], dim=-1) |
|
return x_embed.type_as(x) |
|
|
|
def rotate_half(x): |
|
x1 = x[..., : x.shape[-1] // 2] |
|
x2 = x[..., x.shape[-1] // 2 :] |
|
return torch.cat((-x2, x1), dim=-1) |
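
# Example (for illustration): rotate_half negates and swaps the two halves of the
# last dimension, e.g. rotate_half(torch.tensor([1., 2., 3., 4.])) -> tensor([-3., -4., 1., 2.]),
# which is the standard pairing used by rotary position embeddings.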
|
|
|
|
class attentiona(nn.Module): |
|
def __init__(self, dims: int, head: int): |
|
super().__init__() |
|
self.head = head |
|
self.dims = dims |
|
self.head_dim = dims // head |
|
|
|
self.pad_token = 0 |
|
self.zmin = 1e-6 |
|
self.zmax = 1e-5 |
|
self.zero = nn.Parameter(torch.tensor(1e-4, device=device, dtype=dtype), requires_grad=False) |
|
|
|
self.q = nn.Linear(dims, dims) |
|
self.kv = nn.Linear(dims, dims * 2, bias=False) |
|
self.out = nn.Linear(dims, dims) |
|
|
|
self.lna = nn.LayerNorm(dims) |
|
self.lnb = nn.LayerNorm(dims // head) |
|
self.rope = rotary(dims, head) |
|
|
|
    def forward(self, x, xa = None, mask = None, positions = None):
        zero = self.zero

        q = self.q(x)
        k, v = self.kv(self.lna(x if xa is None else xa)).chunk(2, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h=self.head), (q, k, v))
        # apply the rotary position embedding to queries and keys
        q, k = self.rope(q), self.rope(k)
        scale = q.shape[-1] ** -0.5

        qk = einsum('b h k d, b h q d -> b h k q', self.lnb(q), self.lnb(k)) * scale

        # per-key rescaling: keys whose first feature is exactly zero (padding) are
        # down-weighted by a small frozen constant
        key_scale = torch.ones_like(k[:, :, :, 0])
        zero = torch.clamp(F.softplus(zero), self.zmin, self.zmax)
        key_scale[k[:, :, :, 0].float() == 0] = zero

        if there_is_a(mask):
            # the incoming mask acts as a flag; a causal mask of the right size is built here
            i, j = qk.shape[-2:]
            causal = torch.ones(i, j, device=q.device, dtype=torch.bool).triu(j - i + 1)
            qk = qk.masked_fill(causal, -torch.finfo(qk.dtype).max) * key_scale.unsqueeze(-2).expand(qk.shape)
            qk = F.sigmoid(qk)

        qk = qk * key_scale.unsqueeze(-2)
        qk = taylor_softmax(qk, order=2)

        wv = einsum('b h k q, b h q d -> b h k d', qk, v)
        wv = rearrange(wv, 'b h c d -> b c (h d)')
        out = self.out(wv)
        return out
|
|
|
class tgate(nn.Module): |
|
def __init__(self, dims, num_types=1): |
|
super().__init__() |
|
self.gates = nn.ModuleList([nn.Sequential(nn.Linear(dims, dims), nn.Sigmoid()) for _ in range(num_types)]) |
|
self.classifier = nn.Sequential(nn.Linear(dims, num_types), nn.Softmax(dim=-1)) |
|
def forward(self, x): |
|
types = self.classifier(x) |
|
gates = torch.stack([gate(x) for gate in self.gates], dim=-1) |
|
cgate = torch.sum(gates * types.unsqueeze(2), dim=-1) |
|
return cgate |
|
|
|
class residual(nn.Module): |
|
def __init__(self, dims: int, head: int, layer = 2, act = "silu"): |
|
super().__init__() |
|
|
|
self.lna = nn.LayerNorm(dims, bias=False) |
|
self.atta = attentiona(dims, head) |
|
self.dsl = skip_layer(dims, head, layer=2) |
|
|
|
self.tgate = tgate(dims, num_types=1) |
|
self.mlp = nn.Sequential(nn.Linear(dims, dims*4), get_activation(act), nn.Linear(dims*4, dims)) |
|
|
|
def forward(self, x: Tensor, xa = None, mask = None, positions=None): |
|
|
|
x = x + self.atta(self.lna(x), xa=xa, mask=mask) |
|
x, _ = self.dsl(self.lna(x), xa=xa, mask=mask) |
|
x = x + self.tgate(x) |
|
x = x + self.mlp(self.lna(x)) |
|
|
|
|
|
return x |
|
|
|
class skip_layer(nn.Module): |
|
def __init__(self, dims, head, layer, threshold=0.1): |
|
super().__init__() |
|
self.layers = nn.ModuleList() |
|
self.layer = layer |
|
|
|
self.threshold = threshold |
|
self.dims = dims |
|
self.head = head |
|
self.head_dim = dims // head |
|
|
|
self.attention_module = attentiona(dims, head) |
|
self.node_predictors = nn.ModuleList([ |
|
nn.Sequential( |
|
nn.LayerNorm(dims), |
|
nn.Linear(dims, 1), |
|
nn.Sigmoid() |
|
) for _ in range(layer) |
|
]) |
|
|
|
for i in range(layer): |
|
self.layers.append(nn.ModuleDict({ |
|
'ln': nn.LayerNorm(dims), |
|
'gate': nn.Sequential(nn.Linear(dims, 1), nn.Sigmoid()), |
|
'adapter': nn.Linear(dims, dims) if i % 2 == 0 else None |
|
})) |
|
|
|
self.policy_net = nn.Sequential( |
|
nn.Linear(dims, 128), |
|
nn.ReLU(), |
|
nn.Linear(128, 3)) |
|
|
|
self.jump_weights = nn.Parameter(torch.tensor([0.1, 0.05, 0.01])) |
|
|
|
n_mlp = dims * 4 |
|
self.mlp_gate = nn.Sequential(nn.Linear(dims, 1), nn.Sigmoid()) |
|
self.mlp = nn.Sequential(nn.Linear(dims, n_mlp), nn.GELU(), nn.Linear(n_mlp, dims)) |
|
        self.mlp_ln = nn.LayerNorm(dims)
|
self.working_memory = nn.Parameter(torch.zeros(1, 1, dims)) |
|
self.memory_gate = nn.Sequential(nn.Linear(dims, 1), nn.Sigmoid()) |
|
|
|
    def _calculate_shared_attention(self, x, mask=None):
        # attentiona treats any mask purely as a causal flag, so the importance
        # mask computed by the caller is not forwarded here
        return self.attention_module(x, xa=x, mask=None)
|
|
|
def predict_node_importance(self, x, layer_idx): |
|
importance = self.node_predictors[layer_idx](x) |
|
return (importance > self.threshold).float() |
|
|
|
def forward(self, x, xa=None, mask=None): |
|
batch, ctx = x.shape[:2] |
|
|
|
working_memory = self.working_memory.expand(batch, -1, -1) |
|
original_x = x |
|
pooled_representation = x.mean(dim=1) |
|
policy_logits = self.policy_net(pooled_representation) |
|
policy = F.softmax(policy_logits, dim=-1) |
|
|
|
jump_history = [] |
|
i = 0 |
|
while i < self.layer: |
|
layer = self.layers[i] |
|
node_importance = self.predict_node_importance(x, i) |
|
if node_importance.mean() < 0.2 and i > 0: |
|
i += 1 |
|
jump_history.append(i) |
|
continue |
|
|
|
norm_x = layer['ln'](x) |
|
importance_mask_base = node_importance.unsqueeze(1).contiguous() |
|
combined_custom_mask = None |
|
if mask is None: |
|
combined_custom_mask = importance_mask_base |
|
else: |
|
combined_custom_mask = mask.contiguous() * importance_mask_base |
|
|
|
if node_importance.mean() > 0.3: |
|
attn_output = self._calculate_shared_attention(norm_x, mask=combined_custom_mask.contiguous()) |
|
if layer['adapter'] is not None: |
|
attn_output = layer['adapter'](attn_output) |
|
|
|
gate_value = layer['gate'](norm_x) |
|
x = x + gate_value * attn_output |
|
memory_gate = self.memory_gate(x) |
|
working_memory = memory_gate * working_memory + (1 - memory_gate) * x.mean(dim=1, keepdim=True) |
|
|
|
jump_prob = policy[:, 1] if i < self.layer - 1 else torch.zeros_like(policy[:, 1]) |
|
should_jump = (torch.rand_like(jump_prob) < jump_prob).any() |
|
|
|
if should_jump: |
|
jump_length = torch.multinomial(policy, 1)[:, 0].max().item() + 1 |
|
i_next = min(i + jump_length, self.layer - 1) |
|
skip_weight = self.jump_weights[min(jump_length-1, 2)] |
|
x = x + skip_weight * original_x + (1-skip_weight) * working_memory |
|
i = i_next |
|
jump_history.append(i) |
|
else: |
|
i += 1 |
|
|
|
mlp_importance = self.mlp_gate(x) |
|
mlp_output = self.mlp(self.mlp_ln(x)) |
|
x = x + mlp_importance * mlp_output |
|
return x, {'jumps': jump_history} |
|
|
|
class processor(nn.Module): |
|
    def __init__(self, tokens, mels, ctx, dims, head, head_dim, layer, act):
        super(processor, self).__init__()

        self.dims = dims
        act_fn = get_activation(act)
        self.ln = nn.LayerNorm(dims)
        self.token = nn.Embedding(tokens, dims)
        self.audio = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)

        # learned positional embedding for the text stream, explicitly initialised
        # rather than left as uninitialised memory
        self.positions = nn.Parameter(torch.empty(ctx, dims), requires_grad=True)
        nn.init.normal_(self.positions, std=0.02)
        self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)

        self.encoder = nn.Sequential(
            nn.Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
            nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
            nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)

        modal = False
        # residual expects the activation name, not the instantiated module
        self.block = nn.ModuleList([residual(dims, head, layer, act) for _ in range(layer)]) if modal else None

        self.res = residual(dims, head, layer, act)
        mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)
|
|
|
def init_memory(self, batch): |
|
return torch.zeros(batch, 1, self.dims).to(next(self.parameters()).device) |
|
|
|
def update_memory(self, x, working_memory): |
|
return (x + working_memory) / 2 |
|
|
|
    def forward(self, x, xa, enc=None, sequential=False, modal=False, blend=False, kv_cache=None) -> Tensor:

        mask = self.mask[:x.shape[1], :x.shape[1]]
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = self.token(x.long()) + self.positions[offset : offset + x.shape[-1]]

        xa = self.encoder(xa).permute(0, 2, 1)
        xa = xa + self.audio(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)

        xa = self.res(xa, None, None)
        x = self.res(x, None, mask)
        y = x  # pre-cross-attention state, kept for the optional blend below
        x = self.res(x, xa, None)

        if blend:
            if sequential:
                y = x
            else:
                a = torch.sigmoid(self.blend)
                x = a * x + (1 - a) * y

        if modal and self.block is not None:
            for block in self.block:
                xm = block(torch.cat([x, xa], dim=1), mask=mask)
                y = x
                x = block(xm[:, :x.shape[1]], xm[:, x.shape[1]:], mask=None)
                if blend:
                    if sequential:
                        y = x
                    else:
                        a = torch.sigmoid(self.blend)
                        x = a * x + (1 - a) * y

        x = nn.functional.dropout(x, p=0.001, training=self.training)
        x = self.ln(x)
        x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
        return x
|
|
|
class Model(nn.Module): |
|
def __init__(self, param: Dimensions): |
|
super().__init__() |
|
self.param = param |
|
self.processor = processor( |
|
tokens=param.tokens, |
|
mels=param.mels, |
|
ctx=param.ctx, |
|
dims=param.dims, |
|
head=param.head, |
|
head_dim=param.head_dim, |
|
layer=param.layer, |
|
act=param.act) |
|
|
|
def forward(self, labels=None, input_ids=None, pitch=None, pitch_tokens=None, spectrogram=None, waveform=None): |
|
|
|
x = input_ids |
|
xa = AorB(pitch, spectrogram) |
|
|
|
enc = {} |
|
if spectrogram is not None: |
|
enc["spectrogram"] = spectrogram |
|
if waveform is not None: |
|
enc["waveform"] = waveform |
|
if pitch is not None: |
|
enc["pitch"] = pitch |
|
if pitch_tokens is not None: |
|
enc["ptokens"] = pitch_tokens |
|
|
|
logits = self.processor(x, xa, enc) |
|
loss = None |
|
if labels is not None: |
|
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0) |
|
|
|
return {"logits": logits, "loss": loss} |
|
|
|
    def _init_weights(self, module):
        if isinstance(module, nn.RMSNorm):
            nn.init.ones_(module.weight)
            self.init_counts["RMSNorm"] += 1
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["LayerNorm"] += 1
        elif isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Linear"] += 1
        elif isinstance(module, nn.Conv1d):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Conv1d"] += 1
        elif isinstance(module, nn.Conv2d):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Conv2d"] += 1
        elif isinstance(module, attentiona):
            self.init_counts["attentiona"] += 1
        elif isinstance(module, residual):
            self.init_counts["Residual"] += 1
        elif isinstance(module, processor):
            self.init_counts["processor"] += 1

    def init_weights(self):
        print("Initializing model weights...")
        # reset the counters once, then let nn.Module.apply visit each submodule
        self.init_counts = {
            "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
            "Conv2d": 0, "processor": 0, "attentiona": 0, "Residual": 0}
        self.apply(self._init_weights)
        print("Initialization summary:")
        for module_type, count in self.init_counts.items():
            if count > 0:
                print(f"{module_type}: {count}")
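

# Minimal smoke test (illustrative only): the sizes, batch shapes, and activation
# below are arbitrary assumptions for a quick forward pass, not a recommended or
# original configuration.
if __name__ == "__main__":
    param = Dimensions(
        tokens=1000, mels=80, ctx=256, dims=64,
        head=4, head_dim=16, layer=2, act="gelu")
    model = Model(param).to(device)
    model.init_weights()

    batch, text_len, audio_len = 2, 32, 100
    input_ids = torch.randint(1, param.tokens, (batch, text_len), device=device)
    labels = torch.randint(1, param.tokens, (batch, text_len), device=device)
    spectrogram = torch.randn(batch, param.mels, audio_len, device=device, dtype=dtype)

    out = model(labels=labels, input_ids=input_ids, spectrogram=spectrogram)
    print("logits:", tuple(out["logits"].shape), "loss:", out["loss"].item())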
|
|