Update model_simple.py
model_simple.py  +308 -193  CHANGED
@@ -7,15 +7,29 @@ from torch import nn, Tensor, einsum
 import numpy as np
 from dataclasses import dataclass
 from einops import rearrange
-
-from echoutils import *
-from transformers.trainer_seq2seq import Seq2SeqTrainer
-from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
+
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 dtype = torch.float32
 warnings.filterwarnings("ignore")
 logging.basicConfig(level=logging.ERROR)

+def scaled_relu(x, sequence_length):
+    relu_output = torch.relu(x)
+    return relu_output / sequence_length
+
+def taylor_softmax(x, order=2):
+    tapprox = 1.0
+    for i in range(1, order + 1):
+        factorial_i = torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
+        tapprox += x**i / factorial_i
+    return tapprox / torch.sum(tapprox, dim=-1, keepdim=True)
+
+def there_is_a(a):
+    return a is not None
+
+def AorB(a, b):
+    return a if there_is_a(a) else b
+
 def sinusoids(ctx, dims, max_tscale=10000):
     assert dims % 2 == 0
     pos = torch.log(torch.tensor(float(max_tscale))) / (dims // 2 - 1)
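Note on the helpers added above: `taylor_softmax` replaces the exponential with a truncated Taylor series (at order 2, 1 + x + x²/2, which is strictly positive) and renormalizes over the last dimension, while `scaled_relu` divides a ReLU by the sequence length. A minimal sanity-check sketch, not part of the commit; the test values and tolerance are assumptions:

import torch

def taylor_softmax(x, order=2):
    # mirrors the helper in this diff: truncated series for exp(x), then normalize
    tapprox = 1.0
    for i in range(1, order + 1):
        factorial_i = torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
        tapprox += x**i / factorial_i
    return tapprox / torch.sum(tapprox, dim=-1, keepdim=True)

x = 0.1 * torch.randn(2, 8)
print(torch.allclose(taylor_softmax(x), torch.softmax(x, dim=-1), atol=1e-2))  # True for small logits

For logits far from zero the truncated series drifts away from the true softmax, so it is only a drop-in where the inputs stay small.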
@@ -40,52 +54,159 @@ def get_activation(act: str) -> nn.Module:
     }
     return act_map.get(act, nn.GELU())

-def there_is_a(val):
-    return val is not None
-
 @dataclass
 class Dimensions:
-
+    tokens: int
     mels: int
     ctx: int
     dims: int
     head: int
+    head_dim: int
     layer: int
     act: str

+def vectorized_taylor_sine(x, order=5):
+    original_shape = x.shape
+    x = x.flatten(0, -2)
+    exponents = torch.arange(1, order + 1, 2, device=x.device, dtype=torch.float32)
+    x_powers = x.unsqueeze(-1) ** exponents
+    factorials = torch.exp(torch.lgamma(exponents + 1))
+    signs = (-1)**(torch.arange(0, len(exponents), device=x.device, dtype=torch.float32))
+    terms = signs * x_powers / factorials
+    result = terms.sum(dim=-1)
+    return result.view(original_shape)
+
+def vectorized_taylor_cosine(x, order=5):
+    original_shape = x.shape
+    x = x.flatten(0, -2)
+    exponents = torch.arange(0, order + 1, 2, device=x.device, dtype=torch.float32)
+    x_powers = x.unsqueeze(-1) ** exponents
+    factorials = torch.exp(torch.lgamma(exponents + 1))
+    signs = (-1)**(torch.arange(0, len(exponents), device=x.device, dtype=torch.float32))
+    terms = signs * x_powers / factorials
+    result = terms.sum(dim=-1)
+    return result.view(original_shape)
+
 class rotary(nn.Module):
     def __init__(self, dims, head):
         super(rotary, self).__init__()
         self.dims = dims
         self.head = head
         self.head_dim = dims // head
+        self.taylor_order = 10

-        self.theta = nn.Parameter((torch.tensor(
+        self.theta = nn.Parameter((torch.tensor(360000, device=device, dtype=dtype)), requires_grad=False)
         self.register_buffer('freqs_base', self._compute_freqs_base(), persistent=False)

     def _compute_freqs_base(self):
         mel_scale = torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 4000/200)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1
         return 200 * mel_scale / 1000

-    def forward(self, x) -> Tensor:
-
-
-        freqs =
-
-
-
-
-
-
-
-
-
-
+    def forward(self, x) -> torch.Tensor:
+        positions = (torch.arange(0, x.shape[2], device=x.device))
+        freqs = (self.theta / 220.0) * self.freqs_base
+        freqs = positions[:, None] * freqs
+        freqs_rescaled = (freqs + torch.pi) % (2 * torch.pi) - torch.pi
+
+        with torch.autocast(device_type="cuda", enabled=False):
+            cos = vectorized_taylor_cosine(freqs_rescaled, order=self.taylor_order)
+            sin = vectorized_taylor_sine(freqs_rescaled, order=self.taylor_order)
+        rotary_dim = cos.shape[-1]
+        x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
+        x_embed = (x_rot * cos) + (rotate_half(x_rot) * sin)
+        x_embed = torch.cat([x_embed, x_pass], dim=-1)
+        return x_embed.type_as(x)
+
+def taylor_sine(x, order=5):
+    result = torch.zeros_like(x)
+    for i in range(order + 1):
+        if i % 2 == 1:
+            term = x**i / torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
+            if (i // 2) % 2 == 1:
+                result -= term
+            else:
+                result += term
+    return result
+
+def taylor_cosine(x, order=5):
+    result = torch.zeros_like(x)
+    for i in range(order + 1):
+        if i % 2 == 0:
+            term = x**i / torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
+            if (i // 2) % 2 == 1:
+                result -= term
+            else:
+                result += term
+    return result
+
+class rotarya(nn.Module):
+    def __init__(self, dims, head):
+        super(rotary, self).__init__()
+        self.dims = dims
+        self.head = head
+        self.head_dim = dims // head
+        self.taylor_order = 5
+
+        self.theta = nn.Parameter((torch.tensor(1600, device=device, dtype=dtype)), requires_grad=False)
+        self.register_buffer('freqs_base', self._compute_freqs_base(), persistent=False)
+
+    def _compute_freqs_base(self):
+        mel_scale = torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 4000/200)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1
+        return 200 * mel_scale / 1000
+
+    def forward(self, x) -> torch.Tensor:
+
+        positions = (torch.arange(0, x.shape[2], device=x.device))
+        freqs = (self.theta / 220.0) * self.freqs_base
+        freqs = positions[:, None] * freqs
+        freqs = (freqs + torch.pi) % (2 * torch.pi) - torch.pi
+        with torch.autocast(device_type="cuda", enabled=False):
+            cos = taylor_cosine(freqs, order=self.taylor_order)
+            sin = taylor_sine(freqs, order=self.taylor_order)
+        rotary_dim = cos.shape[-1]
+        x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
+        x_embed = (x_rot * cos) + (rotate_half(x_rot) * sin)
+        x_embed = torch.cat([x_embed, x_pass], dim=-1)
+        return x_embed.type_as(x)
+
+def rotate_half(x):
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+# class rotary(nn.Module):
+#     def __init__(self, dims, head):
+#         super(rotary, self).__init__()
+#         self.dims = dims
+#         self.head = head
+#         self.head_dim = dims // head
+
+#         self.theta = nn.Parameter((torch.tensor(1600, device=device, dtype=dtype)), requires_grad=False)
+#         # self.register_buffer('freqs_base', self._compute_freqs_base(), persistent=False)
+
+#     def _compute_freqs_base(self):
+#         mel_scale = torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 4000/200)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1
+#         return 200 * mel_scale / 1000
+
+#     def forward(self, x) -> Tensor:
+#         positions = (torch.arange(0, x.shape[2], device=x.device))
+#         freqs = (self.theta / 220.0) * self._compute_freqs_base()
+#         freqs = positions[:, None] * freqs
+
+#         with torch.autocast(device_type="cuda", enabled=False):
+#             freqs = torch.polar(torch.ones_like(freqs), freqs)
+#             x1 = x[..., :freqs.shape[-1]*2]
+#             x2 = x[..., freqs.shape[-1]*2:]
+#             orig_shape = x1.shape
+#             x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
+#             x1 = torch.view_as_complex(x1) * freqs
+#             x1 = torch.view_as_real(x1).flatten(-2)
+#             x1 = x1.view(orig_shape)
+#             return torch.cat([x1.type_as(x), x2], dim=-1)

 class attentiona(nn.Module):
     def __init__(self, dims: int, head: int):
         super().__init__()
-
         self.head = head
         self.dims = dims
         self.head_dim = dims // head
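The rewritten `rotary.forward` above wraps its phases into [-π, π] (`freqs_rescaled`) before calling the polynomial sine/cosine, because a Taylor expansion around zero is only accurate near the origin. A small self-contained illustration of that range reduction, not part of the commit; the series degree and tolerance here are assumptions:

import torch

def taylor_sin(x, order=9):
    k = torch.arange(1, order + 1, 2, dtype=torch.float32)         # odd powers 1, 3, ..., 9
    signs = (-1.0) ** torch.arange(len(k), dtype=torch.float32)    # alternating series signs
    return (signs * x.unsqueeze(-1) ** k / torch.exp(torch.lgamma(k + 1))).sum(-1)

phase = torch.linspace(-20.0, 20.0, 101)
wrapped = (phase + torch.pi) % (2 * torch.pi) - torch.pi           # same rescaling as in the diff
print(torch.allclose(taylor_sin(wrapped), torch.sin(phase), atol=1e-2))  # True: accurate once wrapped
print(torch.allclose(taylor_sin(phase), torch.sin(phase), atol=1e-2))    # False: raw phases diverge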
@@ -95,25 +216,23 @@ class attentiona(nn.Module):
         self.zmax = 1e-5
         self.zero = nn.Parameter(torch.tensor(1e-4, device=device, dtype=dtype), requires_grad=False)

-        self.q = nn.Linear(dims, dims
+        self.q = nn.Linear(dims, dims)
         self.kv = nn.Linear(dims, dims * 2, bias=False)
-        self.out = nn.Linear(dims, dims
+        self.out = nn.Linear(dims, dims)

         self.lna = nn.LayerNorm(dims)
+        self.lnb = nn.LayerNorm(dims // head)
         self.rope = rotary(dims, head)

-    def forward(self, x, xa = None, mask = None):
+    def forward(self, x, xa = None, mask = None, positions = None):
         zero = self.zero

-        q = self.q(
+        q = self.q(x)
         k, v = self.kv(self.lna(x if xa is None else xa)).chunk(2, dim=-1)
         q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
         scale = q.shape[-1] ** -0.5

-
-        k = self.rope(k)
-
-        qk = einsum('b h k d, b h q d -> b h k q', q, k) * scale
+        qk = einsum('b h k d, b h q d -> b h k q', self.lnb(q), self.lnb(k)) * scale

         scale = torch.ones_like(k[:, :, :, 0])
         zero = torch.clamp(F.softplus(zero), 1e-6, 1e-5)
@@ -134,7 +253,7 @@ class attentiona(nn.Module):
         return out

 class tgate(nn.Module):
-    def __init__(self, dims, num_types=
+    def __init__(self, dims, num_types=1):
         super().__init__()
         self.gates = nn.ModuleList([nn.Sequential(nn.Linear(dims, dims), nn.Sigmoid()) for _ in range(num_types)])
         self.classifier = nn.Sequential(nn.Linear(dims, num_types), nn.Softmax(dim=-1))
@@ -145,78 +264,189 @@ class tgate(nn.Module):
         return cgate

 class residual(nn.Module):
-    def __init__(self, dims: int, head: int, act
+    def __init__(self, dims: int, head: int, layer = 2, act = "silu"):
         super().__init__()

-        self.lna = nn.LayerNorm(dims, bias=False)
+        self.lna = nn.LayerNorm(dims, bias=False)
         self.atta = attentiona(dims, head)
+        self.dsl = skip_layer(dims, head, layer=2)

         self.tgate = tgate(dims, num_types=1)
         self.mlp = nn.Sequential(nn.Linear(dims, dims*4), get_activation(act), nn.Linear(dims*4, dims))

-    def forward(self, x: Tensor, xa = None, mask = None):
-
-
-
-            x = x + out
-        else:
-            x = out
-        if xa is not None:
-            x = x + self.atta(x, xa, mask=None)
-
+    def forward(self, x: Tensor, xa = None, mask = None, positions=None):
+        # log = {}
+        x = x + self.atta(self.lna(x), xa=xa, mask=mask)
+        x, _ = self.dsl(self.lna(x), xa=xa, mask=mask)  # _ outputs logs for jumps
         x = x + self.tgate(x)
         x = x + self.mlp(self.lna(x))
+        # print(results['jumps'])
+        # log['jumps'] = l
         return x
+
+class skip_layer(nn.Module):
+    def __init__(self, dims, head, layer, threshold=0.1):
+        super().__init__()
+        self.layers = nn.ModuleList()
+        self.layer = layer
+
+        self.threshold = threshold
+        self.dims = dims
+        self.head = head
+        self.head_dim = dims // head
+
+        self.attention_module = attentiona(dims, head)
+        self.node_predictors = nn.ModuleList([
+            nn.Sequential(
+                nn.LayerNorm(dims),
+                nn.Linear(dims, 1),
+                nn.Sigmoid()
+            ) for _ in range(layer)
+        ])
+
+        for i in range(layer):
+            self.layers.append(nn.ModuleDict({
+                'ln': nn.LayerNorm(dims),
+                'gate': nn.Sequential(nn.Linear(dims, 1), nn.Sigmoid()),
+                'adapter': nn.Linear(dims, dims) if i % 2 == 0 else None
+            }))
+
+        self.policy_net = nn.Sequential(
+            nn.Linear(dims, 128),
+            nn.ReLU(),
+            nn.Linear(128, 3))
+
+        self.jump_weights = nn.Parameter(torch.tensor([0.1, 0.05, 0.01]))
+
+        n_mlp = dims * 4
+        self.mlp_gate = nn.Sequential(nn.Linear(dims, 1), nn.Sigmoid())
+        self.mlp = nn.Sequential(nn.Linear(dims, n_mlp), nn.GELU(), nn.Linear(n_mlp, dims))
+        self.mlp_ln = nn.LayerNorm(dims)
+        self.working_memory = nn.Parameter(torch.zeros(1, 1, dims))
+        self.memory_gate = nn.Sequential(nn.Linear(dims, 1), nn.Sigmoid())
+
+    def _calculate_shared_attention(self, x, mask=None):
+        return self.attention_module(x, xa=x, mask=None)
+
+    def predict_node_importance(self, x, layer_idx):
+        importance = self.node_predictors[layer_idx](x)
+        return (importance > self.threshold).float()
+
+    def forward(self, x, xa=None, mask=None):
+        batch, ctx = x.shape[:2]
+
+        working_memory = self.working_memory.expand(batch, -1, -1)
+        original_x = x
+        pooled_representation = x.mean(dim=1)
+        policy_logits = self.policy_net(pooled_representation)
+        policy = F.softmax(policy_logits, dim=-1)
+
+        jump_history = []
+        i = 0
+        while i < self.layer:
+            layer = self.layers[i]
+            node_importance = self.predict_node_importance(x, i)
+            if node_importance.mean() < 0.2 and i > 0:
+                i += 1
+                jump_history.append(i)
+                continue
+
+            norm_x = layer['ln'](x)
+            importance_mask_base = node_importance.unsqueeze(1).contiguous()
+            combined_custom_mask = None
+            if mask is None:
+                combined_custom_mask = importance_mask_base
+            else:
+                combined_custom_mask = mask.contiguous() * importance_mask_base
+
+            if node_importance.mean() > 0.3:
+                attn_output = self._calculate_shared_attention(norm_x, mask=combined_custom_mask.contiguous())
+                if layer['adapter'] is not None:
+                    attn_output = layer['adapter'](attn_output)
+
+                gate_value = layer['gate'](norm_x)
+                x = x + gate_value * attn_output
+                memory_gate = self.memory_gate(x)
+                working_memory = memory_gate * working_memory + (1 - memory_gate) * x.mean(dim=1, keepdim=True)
+
+            jump_prob = policy[:, 1] if i < self.layer - 1 else torch.zeros_like(policy[:, 1])
+            should_jump = (torch.rand_like(jump_prob) < jump_prob).any()
+
+            if should_jump:
+                jump_length = torch.multinomial(policy, 1)[:, 0].max().item() + 1
+                i_next = min(i + jump_length, self.layer - 1)
+                skip_weight = self.jump_weights[min(jump_length-1, 2)]
+                x = x + skip_weight * original_x + (1-skip_weight) * working_memory
+                i = i_next
+                jump_history.append(i)
+            else:
+                i += 1

+        mlp_importance = self.mlp_gate(x)
+        mlp_output = self.mlp(self.mlp_ln(x))
+        x = x + mlp_importance * mlp_output
+        return x, {'jumps': jump_history}

 class processor(nn.Module):
-    def __init__(self,
+    def __init__(self, tokens, mels, ctx, dims, head, head_dim, layer, act):
         super(processor, self).__init__()

+        act_fn = get_activation(act)
         self.ln = nn.LayerNorm(dims)
-        self.token = nn.Embedding(
+        self.token = nn.Embedding(tokens, dims)
         self.audio = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)

         self.positions = nn.Parameter(torch.empty(ctx, dims), requires_grad=True)
         self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
-
-        act_fn = get_activation(act)
+
         self.encoder = nn.Sequential(
-            nn.Conv1d(
+            nn.Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
             nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
             nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)

-
+        modal = False
+        self.block = nn.ModuleList([residual(dims, head, layer, act_fn) for _ in range(layer)]) if modal else None
+
+        self.res = residual(dims, head, layer, act_fn)
         mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)

-    def
-
+    def init_memory(self, batch):
+        return torch.zeros(batch, 1, self.dims).to(next(self.parameters()).device)
+
+    def update_memory(self, x, working_memory):
+        return (x + working_memory) / 2

+    def forward(self, x, xa, enc=None, sequential=False, modal=False, blend=False, kv_cache=None) -> Tensor:
+
+        mask = self.mask[:x.shape[1], :x.shape[1]]
         offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
         x = (self.token(x.long()) + self.positions[offset : offset + x.shape[-1]])

         xa = self.encoder(xa).permute(0, 2, 1)
         xa = xa + self.audio(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        if
-
-
-
-
+        xa = self.res(xa, None, None)
+        x = self.res(x, None, mask)
+        x = self.res(x, xa, None)
+
+        if blend:
+            if sequential:
+                y = x
+            else:
+                a = torch.sigmoid(self.blend)
+                x = a * x + (1 - a) * y
+
+        if modal:
+            for block in chain(self.block or []):
+                xm = block(torch.cat([x, xa], dim=1), mask=mask) if modal else None
+                x = block(xm[:, :x.shape[1]], xm[:, x.shape[1]:], mask=None) if modal else x
+                if blend:
+                    if sequential:
+                        y = x
+                    else:
+                        a = torch.sigmoid(self.blend)
+                        x = a * x + (1 - a) * y

         x = nn.functional.dropout(x, p=0.001, training=self.training)
         x = self.ln(x)
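`skip_layer`, added above, decides per layer whether to run the shared attention (via the node-importance predictor) and whether to jump ahead, sampling a jump from a 3-way policy head over mean-pooled features; it returns `(x, {'jumps': [...]})`, which `residual.forward` unpacks as `x, _ = self.dsl(...)`. A self-contained sketch of just the jump sampling, with an assumed `dims` of 64 and a made-up batch size:

import torch

policy_net = torch.nn.Sequential(torch.nn.Linear(64, 128), torch.nn.ReLU(), torch.nn.Linear(128, 3))
pooled = torch.randn(2, 64)                          # x.mean(dim=1): (batch, dims)
policy = torch.softmax(policy_net(pooled), dim=-1)   # (batch, 3)

jump_prob = policy[:, 1]                             # probability of jumping at this layer
should_jump = (torch.rand_like(jump_prob) < jump_prob).any()
if should_jump:
    jump_length = torch.multinomial(policy, 1)[:, 0].max().item() + 1   # jump 1-3 layers ahead
    print("jump", jump_length)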
@@ -228,18 +458,19 @@ class Model(nn.Module):
         super().__init__()
         self.param = param
         self.processor = processor(
-
+            tokens=param.tokens,
             mels=param.mels,
             ctx=param.ctx,
             dims=param.dims,
             head=param.head,
+            head_dim=param.head_dim,
             layer=param.layer,
             act=param.act)

     def forward(self, labels=None, input_ids=None, pitch=None, pitch_tokens=None, spectrogram=None, waveform=None):

         x = input_ids
-        xa = pitch
+        xa = AorB(pitch, spectrogram)

         enc = {}
         if spectrogram is not None:
@@ -248,6 +479,8 @@ class Model(nn.Module):
             enc["waveform"] = waveform
         if pitch is not None:
             enc["pitch"] = pitch
+        if pitch_tokens is not None:
+            enc["ptokens"] = pitch_tokens

         logits = self.processor(x, xa, enc)
         loss = None
@@ -259,7 +492,7 @@ class Model(nn.Module):
     def _init_weights(self, module):
         self.init_counts = {
             "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
-            "Conv2d": 0, "processor": 0, "
+            "Conv2d": 0, "processor": 0, "attentiona": 0, "Residual": 0}
         for name, module in self.named_modules():
             if isinstance(module, nn.RMSNorm):
                 nn.init.ones_(module.weight)
@@ -295,121 +528,3 @@ class Model(nn.Module):
         for module_type, count in self.init_counts.items():
             if count > 0:
                 print(f"{module_type}: {count}")
-
-def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=True, load_saved=False, save_dataset=True, cache_dir='E:/hf', extract_args=None, max_ctx=2048):
-
-    if load_saved:
-        if cache_dir is None:
-            cache_dir = cache_dir
-        else:
-            cache_dir = cache_dir
-
-        os.makedirs(cache_dir, exist_ok=True)
-        cache_file_train = os.path.join(cache_dir, "train.arrow")
-        cache_file_test = os.path.join(cache_dir, "test.arrow")
-
-        if os.path.exists(cache_file_train) and os.path.exists(cache_file_test):
-            from datasets import Dataset
-            train_dataset = Dataset.load_from_disk(cache_file_train)
-            test_dataset = Dataset.load_from_disk(cache_file_test)
-            return train_dataset, test_dataset
-
-    def filter_func(x):
-        return (0 < len(x["transcription"]) < max_ctx and
-            len(x["audio"]["array"]) > 0 and
-            len(x["audio"]["array"]) < max_ctx * 160)
-
-    raw_train = load_dataset(
-        "google/fleurs", "en_us", token=token, split="train", streaming=streaming).take(1000)
-    raw_test = load_dataset(
-        "google/fleurs", "en_us", token=token, split="test", streaming=streaming).take(100)
-
-    raw_train = raw_train.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
-    raw_test = raw_test.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
-    train_dataset = raw_train.map(lambda x: extract_features(x, tokenizer, **extract_args)).remove_columns(["audio", "transcription"])
-    test_dataset = raw_test.map(lambda x: extract_features(x, tokenizer, **extract_args)).remove_columns(["audio", "transcription"])
-    train_dataset.save_to_disk(cache_file_train) if save_dataset is True else None
-    test_dataset.save_to_disk(cache_file_test) if save_dataset is True else None
-    return train_dataset, test_dataset
-
-def main():
-    token = ""
-    log_dir = os.path.join('D:/newmodel/output/logs/', datetime.now().strftime('%m-%d_%H_%M_%S'))
-    os.makedirs(log_dir, exist_ok=True)
-    tokenizer = setup_tokenizer("D:/newmodel/mod5/tokenizer.json")
-
-    extract_args = {
-        "waveform": False,
-        "spec": False,
-        "pitch_tokens": False,
-        "pitch": True,
-        "harmonics": False,
-        "aperiodics": False,
-        "phase_mod": False,
-        "crepe": False,
-        "sample_rate": 16000,
-        "hop_length": 256,
-        "mode": "mean",
-        "debug": False,
-    }
-
-    param = Dimensions(vocab=40000, mels=128, ctx=2048, dims=512, head=4, layer=4, act="swish")
-
-    train_dataset, test_dataset = prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False,
-        load_saved=False, save_dataset=False, cache_dir=None, extract_args=extract_args, max_ctx=param.ctx)
-
-    model = Model(param).to('cuda')
-    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
-    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
-
-    from functools import partial
-    metrics_fn = partial(compute_metrics, print_pred=True, num_samples=1, tokenizer=tokenizer, model=model)
-
-    training_args = Seq2SeqTrainingArguments(
-        output_dir=log_dir,
-        per_device_train_batch_size=1,
-        per_device_eval_batch_size=1,
-        max_steps=1000,
-        eval_steps=100,
-        save_steps=100,
-        warmup_steps=10,
-        logging_steps=10,
-        logging_dir=log_dir,
-        logging_strategy="steps",
-        eval_strategy="steps",
-        save_strategy="no",
-        report_to=["tensorboard"],
-        push_to_hub=False,
-        save_total_limit=1,
-        label_names=["labels"],
-        save_safetensors=False,
-        eval_on_start=False,
-        batch_eval_metrics=False,
-        disable_tqdm=False,
-        include_tokens_per_second=True,
-        include_num_input_tokens_seen=True,
-        learning_rate=0.00025,
-        weight_decay=0.025,
-    )
-
-    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, eps=1e-10, weight_decay=training_args.weight_decay, betas=(0.9, 0.999), amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False)
-
-    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_args.max_steps, eta_min=1e-9, last_epoch=-1)
-
-    trainer = Seq2SeqTrainer(
-        args=training_args,
-        model=model,
-        train_dataset=train_dataset,
-        eval_dataset=test_dataset,
-        data_collator=DataCollator(tokenizer=tokenizer),
-        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
-        compute_metrics=metrics_fn,
-        optimizers=(optimizer, scheduler)
-    )
-
-    model.init_weights()
-    trainer.train()
-
-if __name__ == "__main__":
-    main()
-
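With `prepare_datasets` and `main` removed and the old `vocab` field gone, model construction now goes through the updated `Dimensions` signature. A hedged construction sketch that reuses the sizes from the deleted `main()` and assumes `tokens` takes over the old `vocab` count and that `head_dim = dims // head`:

# Assumed usage, mirroring the removed main(); not part of this commit.
param = Dimensions(tokens=40000, mels=128, ctx=2048, dims=512,
                   head=4, head_dim=512 // 4, layer=4, act="swish")
model = Model(param).to(device)
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")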