Update model_simple.py
Browse files- model_simple.py +222 -91
model_simple.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
|
2 |
import warnings
|
|
|
3 |
import logging
|
4 |
from itertools import chain
|
5 |
import torch
|
@@ -7,38 +8,18 @@ from torch import nn, Tensor, einsum
|
|
7 |
from typing import Optional
|
8 |
import numpy as np
|
9 |
from dataclasses import dataclass
|
10 |
-
from torch.nn.functional import scaled_dot_product_attention
|
11 |
from einops import rearrange
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
14 |
dtype = torch.float32
|
15 |
warnings.filterwarnings("ignore")
|
16 |
logging.basicConfig(level=logging.ERROR)
|
17 |
|
18 |
-
def sinusoids(ctx, dims, max_tscale=10000):
|
19 |
-
assert dims % 2 == 0
|
20 |
-
pos = torch.log(torch.tensor(float(max_tscale))) / (dims // 2 - 1)
|
21 |
-
tscales = torch.exp(-pos * torch.arange(dims // 2, device=device, dtype=torch.float32))
|
22 |
-
scaled = torch.arange(ctx, device=device, dtype=torch.float32).unsqueeze(1) * tscales.unsqueeze(0)
|
23 |
-
position = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=1)
|
24 |
-
positional_embedding = nn.Parameter(position, requires_grad=True)
|
25 |
-
return positional_embedding
|
26 |
-
|
27 |
-
def get_activation(act: str) -> nn.Module:
|
28 |
-
act_map = {
|
29 |
-
"gelu": nn.GELU(),
|
30 |
-
"relu": nn.ReLU(),
|
31 |
-
"sigmoid": nn.Sigmoid(),
|
32 |
-
"tanh": nn.Tanh(),
|
33 |
-
"swish": nn.SiLU(),
|
34 |
-
"tanhshrink": nn.Tanhshrink(),
|
35 |
-
"softplus": nn.Softplus(),
|
36 |
-
"softshrink": nn.Softshrink(),
|
37 |
-
"leaky_relu": nn.LeakyReLU(),
|
38 |
-
"elu": nn.ELU()
|
39 |
-
}
|
40 |
-
return act_map.get(act, nn.GELU())
|
41 |
-
|
42 |
def there_is_a(val):
|
43 |
return val is not None
|
44 |
|
@@ -105,12 +86,22 @@ class rotary(nn.Module):
|
|
105 |
x1 = x1.view(orig_shape)
|
106 |
return torch.cat([x1.type_as(x), x2], dim=-1)
|
107 |
|
108 |
-
def calculate_attention(q, k, v, mask=None, temp=1.0):
|
109 |
scaled_q = q
|
110 |
if temp != 1.0 and temp > 0:
|
111 |
scaled_q = q * (1.0 / temp)**.5
|
112 |
-
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
return out
|
115 |
|
116 |
class LocalOut(nn.Module):
|
@@ -128,27 +119,24 @@ class LocalOut(nn.Module):
|
|
128 |
return attn_output.transpose(1, 2).contiguous().view(batch, ctx, self.dims)
|
129 |
|
130 |
class attentionb(nn.Module):
|
131 |
-
def __init__(self, dims: int, head: int, max_iter: int = 3,
|
132 |
-
threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1, temp = 1.0):
|
133 |
super(attentionb, self).__init__()
|
134 |
|
135 |
self.head = head
|
136 |
self.dims = dims
|
137 |
self.head_dim = dims // head
|
138 |
-
self.win = 0
|
139 |
|
140 |
self.que = nn.Linear(dims, dims, bias=False)
|
141 |
self.kv = nn.Linear(dims, dims * 2, bias=False)
|
142 |
self.out = nn.Linear(dims, dims, bias=False)
|
143 |
|
144 |
self.lna = nn.LayerNorm(dims)
|
145 |
-
self.lnb = nn.LayerNorm(
|
146 |
self.rope = rotary(dims, head)
|
147 |
|
148 |
self.max_iter = max_iter
|
149 |
self.threshold = nn.Parameter(torch.tensor(threshold), requires_grad=True)
|
150 |
self.temp = nn.Parameter(torch.tensor(temp), requires_grad=True)
|
151 |
-
self.factor = nn.Parameter(torch.tensor(factor), requires_grad=True)
|
152 |
self.local = LocalOut(dims, head)
|
153 |
|
154 |
def update_win(self, win_size=None):
|
@@ -163,7 +151,7 @@ class attentionb(nn.Module):
|
|
163 |
def _focus(self, x, xa = None, mask = None, win_size=None):
|
164 |
|
165 |
q = self.que(self.lna(x))
|
166 |
-
k, v = self.kv(self.lna(x if
|
167 |
q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
|
168 |
|
169 |
self.scale = q.shape[-1] ** -0.35
|
@@ -171,17 +159,14 @@ class attentionb(nn.Module):
|
|
171 |
k = self.rope(k)
|
172 |
|
173 |
iteration = 0
|
174 |
-
temp = self.temp
|
175 |
prev_out = torch.zeros_like(q)
|
176 |
attn_out = torch.zeros_like(q)
|
177 |
threshold = self.threshold
|
178 |
-
factor = self.factor
|
179 |
curq = q #if curq is None else curq
|
180 |
|
181 |
while iteration < self.max_iter:
|
182 |
-
eff_span =
|
183 |
-
if xa is not None:
|
184 |
-
eff_span = min(eff_span, xa.shape[1])
|
185 |
if eff_span == 0:
|
186 |
break
|
187 |
|
@@ -206,24 +191,18 @@ class attentionb(nn.Module):
|
|
206 |
iter_out = torch.zeros_like(curq)
|
207 |
iter_out[:, :, :eff_span, :] = attn_iter
|
208 |
diff = torch.abs(iter_out - prev_out).mean()
|
209 |
-
|
210 |
-
if diff <
|
211 |
attn_out = iter_out
|
212 |
break
|
213 |
|
214 |
prev_out = iter_out.clone()
|
215 |
curq = curq + iter_out
|
216 |
attn_out = iter_out
|
217 |
-
# if win_size is not None:
|
218 |
-
# if win_size > self.win:
|
219 |
-
# temp += 0.005
|
220 |
-
# else:
|
221 |
-
# temp -= 0.005
|
222 |
-
# self.win = win_size
|
223 |
iteration += 1
|
|
|
224 |
|
225 |
-
|
226 |
-
return out
|
227 |
|
228 |
def _slide_win_local(self, x, mask = None) -> Tensor:
|
229 |
|
@@ -260,26 +239,45 @@ class attentionb(nn.Module):
|
|
260 |
def forward(self, x, xa = None, mask = None):
|
261 |
x = self._slide_win_local(x, mask=None)
|
262 |
xa = self._slide_win_local(xa, mask=None)
|
263 |
-
|
264 |
-
return self.out(
|
265 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
class attentiona(nn.Module):
|
267 |
-
def __init__(self, dims: int, head: int, dropout_rate: float = 0.1
|
268 |
super().__init__()
|
|
|
269 |
self.head = head
|
270 |
self.dims = dims
|
271 |
-
self.cross_talk = cross_talk
|
272 |
-
|
273 |
self.que = nn.Linear(dims, dims, bias=False)
|
274 |
self.kv = nn.Linear(dims, dims * 2, bias=False)
|
275 |
self.out = nn.Linear(dims, dims, bias=False)
|
276 |
-
|
277 |
self.ln = nn.LayerNorm(dims)
|
278 |
self.rope = rotary(dims, head)
|
279 |
|
280 |
-
self.x = nn.Conv2d(head, head, 1, bias = False) if cross_talk else None
|
281 |
-
self.xa = nn.Conv2d(head, head, 1, bias = False) if cross_talk else None
|
282 |
-
|
283 |
def forward(self, x, xa = None, mask = None):
|
284 |
|
285 |
q = self.que(self.ln(x))
|
@@ -292,16 +290,13 @@ class attentiona(nn.Module):
|
|
292 |
k = self.rope(k)
|
293 |
|
294 |
qk = einsum('b h k d, b h q d -> b h k q', q, k) * scale
|
|
|
|
|
295 |
|
296 |
-
if there_is_a(mask):
|
297 |
-
i, j = qk.shape[-2:]
|
298 |
-
mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
|
299 |
-
qk = qk.masked_fill(mask, -torch.finfo(qk.dtype).max)
|
300 |
-
|
301 |
-
qk = torch.nn.functional.softmax(qk, dim=-1)
|
302 |
wv = einsum('b h k q, b h q d -> b h k d', qk, v)
|
303 |
wv = rearrange(wv, 'b h c d -> b c (h d)')
|
304 |
out = self.out(wv)
|
|
|
305 |
|
306 |
class attentiond(nn.Module):
|
307 |
def __init__(self, dims: int, head: int):
|
@@ -324,11 +319,11 @@ class attentiond(nn.Module):
|
|
324 |
qk, v = self.kv(self.ln(x)).chunk(2, dim=-1)
|
325 |
qka, va = self.kv(self.ln(x if xa is None else xa)).chunk(2, dim=-1)
|
326 |
qk, qka, v, va = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.head), (qk, qka, v, va))
|
327 |
-
qk = einsum('b h
|
328 |
if there_is_a(mask):
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
x = qk.softmax(dim = -1)
|
333 |
xa = qk.softmax(dim = -2)
|
334 |
x = self.x(x)
|
@@ -337,15 +332,13 @@ class attentiond(nn.Module):
|
|
337 |
xa = einsum('b h j i, b h j d -> b h i d', xa, v)
|
338 |
x, xa = map(lambda t: rearrange(t, 'b h n d -> b n (h d)'), (x, xa))
|
339 |
out = self.out(x)
|
340 |
-
|
341 |
-
|
342 |
-
return out, outxa, qk
|
343 |
|
344 |
class tgate(nn.Module):
|
345 |
def __init__(self, dims, num_types=4):
|
346 |
super().__init__()
|
347 |
-
self.gates = nn.ModuleList([nn.Sequential(nn.Linear(dims,
|
348 |
-
self.classifier = nn.Sequential(nn.Linear(dims, num_types), nn.Softmax(dim=-1))
|
349 |
def forward(self, x):
|
350 |
types = self.classifier(x)
|
351 |
gates = torch.stack([gate(x) for gate in self.gates], dim=-1)
|
@@ -359,32 +352,34 @@ class residual(nn.Module):
|
|
359 |
self.lna = nn.LayerNorm(dims, bias=False)
|
360 |
self.atta = attentiona(dims, head)
|
361 |
self.attb = attentionb(dims, head, max_iter=1)
|
362 |
-
|
363 |
|
364 |
-
self.tgate = tgate(dims, num_types=
|
365 |
self.mlp = nn.Sequential(nn.Linear(dims, dims*4), get_activation(act), nn.Linear(dims*4, dims))
|
366 |
|
367 |
-
def forward(
|
368 |
-
|
369 |
-
x
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
if xa is not None:
|
375 |
-
x = x + self.
|
376 |
x = x + self.tgate(x)
|
377 |
-
x = x + self.mlp(self.lna(x))
|
378 |
return x
|
379 |
|
380 |
class processor(nn.Module):
|
381 |
def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
|
|
|
382 |
super(processor, self).__init__()
|
383 |
|
384 |
self.ln = nn.LayerNorm(dims)
|
385 |
self.token = nn.Embedding(vocab, dims)
|
386 |
self.audio = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
|
387 |
self.positions = nn.Parameter(torch.empty(ctx, dims), requires_grad=True)
|
|
|
388 |
|
389 |
act_fn = get_activation(act)
|
390 |
self.encoder = nn.Sequential(
|
@@ -393,12 +388,13 @@ class processor(nn.Module):
|
|
393 |
nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
|
394 |
|
395 |
self.blocka = nn.ModuleList([residual(dims, head, act_fn) for _ in range(layer)])
|
396 |
-
self.blockm = nn.ModuleList([residual(dims, head, act_fn) for _ in range(
|
397 |
|
|
|
398 |
mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
|
399 |
self.register_buffer("mask", mask, persistent=False)
|
400 |
|
401 |
-
def forward(self, x, xa, sequential=False, modal=False, kv_cache=None) -> Tensor:
|
402 |
|
403 |
if xa.dim() == 2:
|
404 |
xa = xa.unsqueeze(0)
|
@@ -413,10 +409,22 @@ class processor(nn.Module):
|
|
413 |
xa = block(xa, mask=None)
|
414 |
x = block(x, mask=self.mask)
|
415 |
x = block(x, xa, mask=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
416 |
|
417 |
for block in chain(self.blockm or []):
|
418 |
xm = block(torch.cat([x, xa], dim=1), torch.cat([x, xa], dim=1), mask=None) if modal else None
|
419 |
x = block(xm[:, :x.shape[1]], xm[:, x.shape[1]:], mask=None) if modal else x
|
|
|
|
|
|
|
|
|
|
|
|
|
420 |
|
421 |
x = nn.functional.dropout(x, p=0.001, training=self.training)
|
422 |
x = self.ln(x)
|
@@ -447,7 +455,7 @@ class Model(nn.Module):
|
|
447 |
def adjust_window(self, loss, ctx):
|
448 |
self.win_size = ((ctx // self.param.head))
|
449 |
if loss < self.best_loss:
|
450 |
-
win_size = (self.win_size * self.factor)
|
451 |
else:
|
452 |
win_size = (self.win_size // self.factor).clamp(0, self.win_size - 1)
|
453 |
self.win_size = win_size
|
@@ -455,15 +463,25 @@ class Model(nn.Module):
|
|
455 |
self.update(win_size)
|
456 |
return win_size
|
457 |
|
458 |
-
def forward(self, labels=None, input_ids=None, pitch=None, pitch_tokens=None, spectrogram=None):
|
459 |
|
460 |
x = input_ids
|
461 |
-
xa = pitch
|
462 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
463 |
loss = None
|
464 |
if labels is not None:
|
465 |
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
|
466 |
-
self.adjust_window(loss=loss.item(), ctx=xa.shape[
|
467 |
return {"logits": logits, "loss": loss}
|
468 |
|
469 |
def _init_weights(self, module):
|
@@ -522,3 +540,116 @@ class Model(nn.Module):
|
|
522 |
hooks.append(layer.v.register_forward_hook(save_to_cache))
|
523 |
self.processor.apply(install_hooks)
|
524 |
return cache, hooks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
|
2 |
import warnings
|
3 |
+
import os
|
4 |
import logging
|
5 |
from itertools import chain
|
6 |
import torch
|
|
|
8 |
from typing import Optional
|
9 |
import numpy as np
|
10 |
from dataclasses import dataclass
|
|
|
11 |
from einops import rearrange
|
12 |
+
from datasets import load_dataset, Audio
|
13 |
+
from echoutils import extract_features, setup_tokenizer, compute_metrics, DataCollator, preprocess_logits_for_metrics, sinusoids, get_activation
|
14 |
+
from datetime import datetime
|
15 |
+
from transformers.trainer_seq2seq import Seq2SeqTrainer
|
16 |
+
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
|
17 |
|
18 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
19 |
dtype = torch.float32
|
20 |
warnings.filterwarnings("ignore")
|
21 |
logging.basicConfig(level=logging.ERROR)
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
def there_is_a(val):
|
24 |
return val is not None
|
25 |
|
|
|
86 |
x1 = x1.view(orig_shape)
|
87 |
return torch.cat([x1.type_as(x), x2], dim=-1)
|
88 |
|
89 |
+
def calculate_attention(q, k, v, mask=None, temp=1.0, pytorch=True):
|
90 |
scaled_q = q
|
91 |
if temp != 1.0 and temp > 0:
|
92 |
scaled_q = q * (1.0 / temp)**.5
|
93 |
+
if pytorch:
|
94 |
+
out = torch.nn.functional.scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1)
|
95 |
+
else:
|
96 |
+
scale = q.shape[-1] ** -0.35
|
97 |
+
qk = (q * scale) @ (k * scale).transpose(-1, -2)
|
98 |
+
if there_is_a(mask):
|
99 |
+
mask = mask[:qk.shape[2], :qk.shape[2]]
|
100 |
+
qk = qk.masked_fill(mask.bool(), -torch.inf)
|
101 |
+
qk = qk.float()
|
102 |
+
w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype)
|
103 |
+
out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
|
104 |
+
qk = qk.detach()
|
105 |
return out
|
106 |
|
107 |
class LocalOut(nn.Module):
|
|
|
119 |
return attn_output.transpose(1, 2).contiguous().view(batch, ctx, self.dims)
|
120 |
|
121 |
class attentionb(nn.Module):
|
122 |
+
def __init__(self, dims: int, head: int, max_iter: int = 3, threshold: float = 0.5, temp = 1.0):
|
|
|
123 |
super(attentionb, self).__init__()
|
124 |
|
125 |
self.head = head
|
126 |
self.dims = dims
|
127 |
self.head_dim = dims // head
|
|
|
128 |
|
129 |
self.que = nn.Linear(dims, dims, bias=False)
|
130 |
self.kv = nn.Linear(dims, dims * 2, bias=False)
|
131 |
self.out = nn.Linear(dims, dims, bias=False)
|
132 |
|
133 |
self.lna = nn.LayerNorm(dims)
|
134 |
+
self.lnb = nn.LayerNorm(dims // head)
|
135 |
self.rope = rotary(dims, head)
|
136 |
|
137 |
self.max_iter = max_iter
|
138 |
self.threshold = nn.Parameter(torch.tensor(threshold), requires_grad=True)
|
139 |
self.temp = nn.Parameter(torch.tensor(temp), requires_grad=True)
|
|
|
140 |
self.local = LocalOut(dims, head)
|
141 |
|
142 |
def update_win(self, win_size=None):
|
|
|
151 |
def _focus(self, x, xa = None, mask = None, win_size=None):
|
152 |
|
153 |
q = self.que(self.lna(x))
|
154 |
+
k, v = self.kv(self.lna(x if xa is None else xa)).chunk(2, dim=-1)
|
155 |
q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
|
156 |
|
157 |
self.scale = q.shape[-1] ** -0.35
|
|
|
159 |
k = self.rope(k)
|
160 |
|
161 |
iteration = 0
|
162 |
+
temp = self.temp.item()
|
163 |
prev_out = torch.zeros_like(q)
|
164 |
attn_out = torch.zeros_like(q)
|
165 |
threshold = self.threshold
|
|
|
166 |
curq = q #if curq is None else curq
|
167 |
|
168 |
while iteration < self.max_iter:
|
169 |
+
eff_span = curq.shape[2]
|
|
|
|
|
170 |
if eff_span == 0:
|
171 |
break
|
172 |
|
|
|
191 |
iter_out = torch.zeros_like(curq)
|
192 |
iter_out[:, :, :eff_span, :] = attn_iter
|
193 |
diff = torch.abs(iter_out - prev_out).mean()
|
194 |
+
|
195 |
+
if diff < threshold and iteration > 0:
|
196 |
attn_out = iter_out
|
197 |
break
|
198 |
|
199 |
prev_out = iter_out.clone()
|
200 |
curq = curq + iter_out
|
201 |
attn_out = iter_out
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
iteration += 1
|
203 |
+
temp -= 0.005
|
204 |
|
205 |
+
return rearrange(attn_out, 'b h c d -> b c (h d)')
|
|
|
206 |
|
207 |
def _slide_win_local(self, x, mask = None) -> Tensor:
|
208 |
|
|
|
239 |
def forward(self, x, xa = None, mask = None):
|
240 |
x = self._slide_win_local(x, mask=None)
|
241 |
xa = self._slide_win_local(xa, mask=None)
|
242 |
+
out = self._focus(x, xa, mask=None)
|
243 |
+
return self.out(out)
|
244 |
|
245 |
+
def scaled_relu(x, sequence_length):
|
246 |
+
relu_output = torch.relu(x)
|
247 |
+
return relu_output / sequence_length
|
248 |
+
|
249 |
+
def taylor_softmax(x, order=2):
|
250 |
+
taylor_approx = 1.0
|
251 |
+
for i in range(1, order + 1):
|
252 |
+
factorial_i = torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
|
253 |
+
taylor_approx += x**i / factorial_i
|
254 |
+
return taylor_approx / torch.sum(taylor_approx, dim=-1, keepdim=True)
|
255 |
+
|
256 |
+
def taylor_softmax_2nd_order(x):
|
257 |
+
exp_approx = 1 + x + (x**2) / 2
|
258 |
+
return exp_approx / torch.sum(exp_approx, dim=-1, keepdim=True)
|
259 |
+
|
260 |
+
def cos_sim(q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
|
261 |
+
q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
|
262 |
+
k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
|
263 |
+
qk_cosine = torch.matmul(q_norm, k_norm.transpose(-1, -2))
|
264 |
+
qk_cosine = qk_cosine + mask
|
265 |
+
weights = F.softmax(qk_cosine, dim=-1)
|
266 |
+
out = torch.matmul(weights, v)
|
267 |
+
return out
|
268 |
+
|
269 |
class attentiona(nn.Module):
|
270 |
+
def __init__(self, dims: int, head: int, dropout_rate: float = 0.1):
|
271 |
super().__init__()
|
272 |
+
|
273 |
self.head = head
|
274 |
self.dims = dims
|
|
|
|
|
275 |
self.que = nn.Linear(dims, dims, bias=False)
|
276 |
self.kv = nn.Linear(dims, dims * 2, bias=False)
|
277 |
self.out = nn.Linear(dims, dims, bias=False)
|
|
|
278 |
self.ln = nn.LayerNorm(dims)
|
279 |
self.rope = rotary(dims, head)
|
280 |
|
|
|
|
|
|
|
281 |
def forward(self, x, xa = None, mask = None):
|
282 |
|
283 |
q = self.que(self.ln(x))
|
|
|
290 |
k = self.rope(k)
|
291 |
|
292 |
qk = einsum('b h k d, b h q d -> b h k q', q, k) * scale
|
293 |
+
# qk = torch.nn.functional.softmax(qk, dim=-1)
|
294 |
+
qk = taylor_softmax(qk, order=2)
|
295 |
|
|
|
|
|
|
|
|
|
|
|
|
|
296 |
wv = einsum('b h k q, b h q d -> b h k d', qk, v)
|
297 |
wv = rearrange(wv, 'b h c d -> b c (h d)')
|
298 |
out = self.out(wv)
|
299 |
+
return out
|
300 |
|
301 |
class attentiond(nn.Module):
|
302 |
def __init__(self, dims: int, head: int):
|
|
|
319 |
qk, v = self.kv(self.ln(x)).chunk(2, dim=-1)
|
320 |
qka, va = self.kv(self.ln(x if xa is None else xa)).chunk(2, dim=-1)
|
321 |
qk, qka, v, va = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.head), (qk, qka, v, va))
|
322 |
+
qk = einsum('b h q d, b h k d -> b h q k', qk, qka)
|
323 |
if there_is_a(mask):
|
324 |
+
mask = mask[:qk.shape[2], :qk.shape[2]]
|
325 |
+
qk = qk.masked_fill(mask.bool(), -torch.inf)
|
326 |
+
|
327 |
x = qk.softmax(dim = -1)
|
328 |
xa = qk.softmax(dim = -2)
|
329 |
x = self.x(x)
|
|
|
332 |
xa = einsum('b h j i, b h j d -> b h i d', xa, v)
|
333 |
x, xa = map(lambda t: rearrange(t, 'b h n d -> b n (h d)'), (x, xa))
|
334 |
out = self.out(x)
|
335 |
+
return out
|
|
|
|
|
336 |
|
337 |
class tgate(nn.Module):
|
338 |
def __init__(self, dims, num_types=4):
|
339 |
super().__init__()
|
340 |
+
self.gates = nn.ModuleList([nn.Sequential(nn.Linear(dims, dims), nn.Sigmoid()) for _ in range(num_types)])
|
341 |
+
self.classifier = nn.Sequential(nn.Linear(dims, num_types), torch.nn.functional.Softmax(dim=-1))
|
342 |
def forward(self, x):
|
343 |
types = self.classifier(x)
|
344 |
gates = torch.stack([gate(x) for gate in self.gates], dim=-1)
|
|
|
352 |
self.lna = nn.LayerNorm(dims, bias=False)
|
353 |
self.atta = attentiona(dims, head)
|
354 |
self.attb = attentionb(dims, head, max_iter=1)
|
355 |
+
self.attc = attentiond(dims, head)
|
356 |
|
357 |
+
self.tgate = tgate(dims, num_types=1)
|
358 |
self.mlp = nn.Sequential(nn.Linear(dims, dims*4), get_activation(act), nn.Linear(dims*4, dims))
|
359 |
|
360 |
+
def forward(self, x: Tensor, xa = None, mask = None):
|
361 |
+
|
362 |
+
out = self.atta(x, mask=mask)
|
363 |
+
if x.shape == out.shape:
|
364 |
+
x = x + out
|
365 |
+
else:
|
366 |
+
x = out
|
367 |
if xa is not None:
|
368 |
+
x = x + self.atta(x, xa, mask=None)
|
369 |
x = x + self.tgate(x)
|
370 |
+
x = x + self.mlp(self.lna(x))
|
371 |
return x
|
372 |
|
373 |
class processor(nn.Module):
|
374 |
def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
|
375 |
+
|
376 |
super(processor, self).__init__()
|
377 |
|
378 |
self.ln = nn.LayerNorm(dims)
|
379 |
self.token = nn.Embedding(vocab, dims)
|
380 |
self.audio = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
|
381 |
self.positions = nn.Parameter(torch.empty(ctx, dims), requires_grad=True)
|
382 |
+
self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
|
383 |
|
384 |
act_fn = get_activation(act)
|
385 |
self.encoder = nn.Sequential(
|
|
|
388 |
nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
|
389 |
|
390 |
self.blocka = nn.ModuleList([residual(dims, head, act_fn) for _ in range(layer)])
|
391 |
+
self.blockm = nn.ModuleList([residual(dims, head, act_fn) for _ in range(2)])
|
392 |
|
393 |
+
mask = torch.triu(torch.ones(ctx, ctx), diagonal=1)
|
394 |
mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
|
395 |
self.register_buffer("mask", mask, persistent=False)
|
396 |
|
397 |
+
def forward(self, x, xa, xb, sequential=False, modal=False, kv_cache=None, blend=False) -> Tensor:
|
398 |
|
399 |
if xa.dim() == 2:
|
400 |
xa = xa.unsqueeze(0)
|
|
|
409 |
xa = block(xa, mask=None)
|
410 |
x = block(x, mask=self.mask)
|
411 |
x = block(x, xa, mask=None)
|
412 |
+
if blend:
|
413 |
+
if sequential:
|
414 |
+
y = x
|
415 |
+
else:
|
416 |
+
a = torch.sigmoid(self.blend)
|
417 |
+
x = a * x + (1 - a) * y
|
418 |
|
419 |
for block in chain(self.blockm or []):
|
420 |
xm = block(torch.cat([x, xa], dim=1), torch.cat([x, xa], dim=1), mask=None) if modal else None
|
421 |
x = block(xm[:, :x.shape[1]], xm[:, x.shape[1]:], mask=None) if modal else x
|
422 |
+
if blend:
|
423 |
+
if sequential:
|
424 |
+
y = x
|
425 |
+
else:
|
426 |
+
a = torch.sigmoid(self.blend)
|
427 |
+
x = a * x + (1 - a) * y
|
428 |
|
429 |
x = nn.functional.dropout(x, p=0.001, training=self.training)
|
430 |
x = self.ln(x)
|
|
|
455 |
def adjust_window(self, loss, ctx):
|
456 |
self.win_size = ((ctx // self.param.head))
|
457 |
if loss < self.best_loss:
|
458 |
+
win_size = (self.win_size * self.factor)
|
459 |
else:
|
460 |
win_size = (self.win_size // self.factor).clamp(0, self.win_size - 1)
|
461 |
self.win_size = win_size
|
|
|
463 |
self.update(win_size)
|
464 |
return win_size
|
465 |
|
466 |
+
def forward(self, labels=None, input_ids=None, pitch=None, pitch_tokens=None, spectrogram=None, waveform=None):
|
467 |
|
468 |
x = input_ids
|
469 |
+
xa = pitch
|
470 |
+
xb = spectrogram
|
471 |
+
|
472 |
+
enc = {}
|
473 |
+
if spectrogram is not None:
|
474 |
+
enc["spectrogram"] = spectrogram
|
475 |
+
if waveform is not None:
|
476 |
+
enc["waveform"] = waveform
|
477 |
+
if pitch is not None:
|
478 |
+
enc["pitch"] = pitch
|
479 |
+
|
480 |
+
logits = self.processor(x, xa, xb)
|
481 |
loss = None
|
482 |
if labels is not None:
|
483 |
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
|
484 |
+
self.adjust_window(loss=loss.item(), ctx=xa.shape[1])
|
485 |
return {"logits": logits, "loss": loss}
|
486 |
|
487 |
def _init_weights(self, module):
|
|
|
540 |
hooks.append(layer.v.register_forward_hook(save_to_cache))
|
541 |
self.processor.apply(install_hooks)
|
542 |
return cache, hooks
|
543 |
+
|
544 |
+
### "pipeline"
|
545 |
+
|
546 |
+
def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=True, load_saved=False, save_dataset=True, cache_dir='E:/hf', extract_args=None, max_ctx=2048):
|
547 |
+
|
548 |
+
if load_saved:
|
549 |
+
if cache_dir is None:
|
550 |
+
cache_dir = cache_dir
|
551 |
+
else:
|
552 |
+
cache_dir = cache_dir
|
553 |
+
os.makedirs(cache_dir, exist_ok=True)
|
554 |
+
cache_file_train = os.path.join(cache_dir, "train.arrow")
|
555 |
+
cache_file_test = os.path.join(cache_dir, "test.arrow")
|
556 |
+
if os.path.exists(cache_file_train) and os.path.exists(cache_file_test):
|
557 |
+
from datasets import Dataset
|
558 |
+
train_dataset = Dataset.load_from_disk(cache_file_train)
|
559 |
+
test_dataset = Dataset.load_from_disk(cache_file_test)
|
560 |
+
return train_dataset, test_dataset
|
561 |
+
def filter_func(x):
|
562 |
+
return (0 < len(x["transcription"]) < max_ctx and
|
563 |
+
len(x["audio"]["array"]) > 0 and
|
564 |
+
len(x["audio"]["array"]) < max_ctx * 160)
|
565 |
+
|
566 |
+
raw_train = load_dataset("mozilla-foundation/common_voice_17_0", "en", token=token, split="train", trust_remote_code=True, streaming=True).rename_column("sentence", "transcription")
|
567 |
+
raw_test = load_dataset("mozilla-foundation/common_voice_17_0", "en", token=token, split="test", trust_remote_code=True, streaming=True).rename_column("sentence", "transcription").take(1000)
|
568 |
+
|
569 |
+
raw_train = raw_train.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
|
570 |
+
raw_test = raw_test.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
|
571 |
+
train_dataset = raw_train.map(lambda x: extract_features(x, tokenizer, **extract_args)).remove_columns(["audio", "transcription"])
|
572 |
+
test_dataset = raw_test.map(lambda x: extract_features(x, tokenizer, **extract_args)).remove_columns(["audio", "transcription"])
|
573 |
+
train_dataset.save_to_disk(cache_file_train) if save_dataset is True else None
|
574 |
+
test_dataset.save_to_disk(cache_file_test) if save_dataset is True else None
|
575 |
+
return train_dataset, test_dataset
|
576 |
+
|
577 |
+
def main():
|
578 |
+
token = ""
|
579 |
+
log_dir = os.path.join('D:/newmodel/output/logs/', datetime.now().strftime('%m-%d_%H_%M_%S'))
|
580 |
+
os.makedirs(log_dir, exist_ok=True)
|
581 |
+
tokenizer = setup_tokenizer("D:/newmodel/mod5/tokenizer.json")
|
582 |
+
|
583 |
+
extract_args = {
|
584 |
+
"waveform": True,
|
585 |
+
"spec": True,
|
586 |
+
"pitch_tokens": True,
|
587 |
+
"pitch": True,
|
588 |
+
"harmonics": False,
|
589 |
+
"aperiodics": False,
|
590 |
+
"phase_mod": False,
|
591 |
+
"crepe": False,
|
592 |
+
"sample_rate": 16000,
|
593 |
+
"hop_length": 256,
|
594 |
+
"mode": "mean",
|
595 |
+
"debug": False,
|
596 |
+
}
|
597 |
+
|
598 |
+
param = Dimensions(vocab=40000, mels=128, ctx=2048, dims=512, head=4, layer=4, act="swish")
|
599 |
+
|
600 |
+
train_dataset, test_dataset = prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False,
|
601 |
+
load_saved=False, save_dataset=False, cache_dir=None, extract_args=extract_args, max_ctx=param.ctx)
|
602 |
+
|
603 |
+
model = Model(param).to('cuda')
|
604 |
+
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
|
605 |
+
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
|
606 |
+
|
607 |
+
from functools import partial
|
608 |
+
metrics_fn = partial(compute_metrics, print_pred=True, num_samples=1, tokenizer=tokenizer, model=model)
|
609 |
+
|
610 |
+
training_args = Seq2SeqTrainingArguments(
|
611 |
+
output_dir=log_dir,
|
612 |
+
per_device_train_batch_size=1,
|
613 |
+
per_device_eval_batch_size=1,
|
614 |
+
max_steps=100000,
|
615 |
+
eval_steps=1000,
|
616 |
+
save_steps=1000,
|
617 |
+
warmup_steps=1000,
|
618 |
+
logging_steps=100,
|
619 |
+
logging_dir=log_dir,
|
620 |
+
logging_strategy="steps",
|
621 |
+
eval_strategy="steps",
|
622 |
+
save_strategy="no",
|
623 |
+
report_to=["tensorboard"],
|
624 |
+
push_to_hub=False,
|
625 |
+
save_total_limit=1,
|
626 |
+
label_names=["labels"],
|
627 |
+
save_safetensors=False,
|
628 |
+
eval_on_start=False,
|
629 |
+
batch_eval_metrics=False,
|
630 |
+
disable_tqdm=False,
|
631 |
+
include_tokens_per_second=True,
|
632 |
+
include_num_input_tokens_seen=True,
|
633 |
+
learning_rate=0.00025,
|
634 |
+
weight_decay=0.025,
|
635 |
+
)
|
636 |
+
|
637 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, eps=1e-8, weight_decay=training_args.weight_decay, betas=(0.9, 0.999), amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False)
|
638 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_args.max_steps, eta_min=1e-9, last_epoch=-1)
|
639 |
+
|
640 |
+
trainer = Seq2SeqTrainer(
|
641 |
+
args=training_args,
|
642 |
+
model=model,
|
643 |
+
train_dataset=train_dataset,
|
644 |
+
eval_dataset=test_dataset,
|
645 |
+
data_collator=DataCollator(tokenizer=tokenizer),
|
646 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
647 |
+
compute_metrics=metrics_fn,
|
648 |
+
optimizers=(optimizer, scheduler)
|
649 |
+
)
|
650 |
+
|
651 |
+
model.init_weights()
|
652 |
+
trainer.train()
|
653 |
+
|
654 |
+
if __name__ == "__main__":
|
655 |
+
main()
|