Update model_simple.py

model_simple.py  +82 -327  CHANGED

Before (lines removed by this commit are prefixed with "-"):
@@ -1,25 +1,45 @@
 
 import warnings
-import os
 import logging
 from itertools import chain
 import torch
 from torch import nn, Tensor, einsum
-from typing import Optional
 import numpy as np
 from dataclasses import dataclass
 from einops import rearrange
-from datasets import load_dataset, Audio
-from echoutils import extract_features, setup_tokenizer, compute_metrics, DataCollator, preprocess_logits_for_metrics, sinusoids, get_activation
 from datetime import datetime
 from transformers.trainer_seq2seq import Seq2SeqTrainer
 from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
-
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 dtype = torch.float32
 warnings.filterwarnings("ignore")
 logging.basicConfig(level=logging.ERROR)
 
 def there_is_a(val):
     return val is not None
 
@@ -33,29 +53,6 @@ class Dimensions:
     layer: int
     act: str
 
-def qkv_init(dims, head):
-    head_dim = dims // head
-    q = nn.Linear(dims, dims)
-    k = nn.Linear(dims, dims)
-    v = nn.Linear(dims, dims)
-    o = nn.Linear(dims, dims)
-    lna = nn.LayerNorm(dims)
-    lnb = nn.LayerNorm(dims)
-    lnc = nn.LayerNorm(head_dim)
-    lnd = nn.LayerNorm(head_dim)
-    return q, k, v, o, lna, lnb, lnc, lnd
-
-def shape(dims, head, q, k, v):
-    batch_size = q.shape[0]
-    seq_len_q = q.shape[1]
-    seq_len_kv = k.shape[1]
-    head_dim = dims // head
-
-    q = q.view(batch_size, seq_len_q, head, head_dim).transpose(1, 2)
-    k = k.view(batch_size, seq_len_kv, head, head_dim).transpose(1, 2)
-    v = v.view(batch_size, seq_len_kv, head, head_dim).transpose(1, 2)
-    return q, k, v
-
 class rotary(nn.Module):
     def __init__(self, dims, head):
         super(rotary, self).__init__()
@@ -63,7 +60,7 @@ class rotary(nn.Module):
         self.head = head
         self.head_dim = dims // head
 
-        self.theta = nn.Parameter((torch.tensor(
         self.register_buffer('freqs_base', self._compute_freqs_base(), persistent=False)
 
     def _compute_freqs_base(self):
@@ -72,10 +69,9 @@ class rotary(nn.Module):
 
     def forward(self, x) -> Tensor:
         freqs = (self.theta / 220.0) * self.freqs_base
-
         pos = torch.arange(x.shape[2], device=device, dtype=dtype)
         freqs = pos[:, None] * freqs
-        freqs=torch.polar(torch.ones_like(freqs), freqs)
 
         x1 = x[..., :freqs.shape[-1]*2]
         x2 = x[..., freqs.shape[-1]*2:]
@@ -86,203 +82,31 @@ class rotary(nn.Module):
         x1 = x1.view(orig_shape)
         return torch.cat([x1.type_as(x), x2], dim=-1)
 
-
-    scaled_q = q
-    if temp != 1.0 and temp > 0:
-        scaled_q = q * (1.0 / temp)**.5
-    if pytorch:
-        out = torch.nn.functional.scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1)
-    else:
-        scale = q.shape[-1] ** -0.35
-        qk = (q * scale) @ (k * scale).transpose(-1, -2)
-        if there_is_a(mask):
-            mask = mask[:qk.shape[2], :qk.shape[2]]
-            qk = qk.masked_fill(mask.bool(), -torch.inf)
-        qk = qk.float()
-        w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype)
-        out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
-        qk = qk.detach()
-    return out
-
-class LocalOut(nn.Module):
     def __init__(self, dims: int, head: int):
         super().__init__()
-        self.head_dim = dims // head
-        self.dims = dims
-        self.q_hd = nn.Linear(self.head_dim, self.head_dim)
-        self.k_hd = nn.Linear(self.head_dim, self.head_dim)
-        self.v_hd = nn.Linear(self.head_dim, self.head_dim)
-        self.out = nn.Linear(self.head_dim, self.head_dim)
-
-    def _reshape_to_output(self, attn_output: Tensor) -> Tensor:
-        batch, _, ctx, _ = attn_output.shape
-        return attn_output.transpose(1, 2).contiguous().view(batch, ctx, self.dims)
-
-class attentionb(nn.Module):
-    def __init__(self, dims: int, head: int, max_iter: int = 3, threshold: float = 0.5, temp = 1.0):
-        super(attentionb, self).__init__()
 
         self.head = head
         self.dims = dims
         self.head_dim = dims // head
 
-        self.
         self.kv = nn.Linear(dims, dims * 2, bias=False)
         self.out = nn.Linear(dims, dims, bias=False)
 
         self.lna = nn.LayerNorm(dims)
-        self.lnb = nn.LayerNorm(dims // head)
         self.rope = rotary(dims, head)
 
-        self.max_iter = max_iter
-        self.threshold = nn.Parameter(torch.tensor(threshold), requires_grad=True)
-        self.temp = nn.Parameter(torch.tensor(temp), requires_grad=True)
-        self.local = LocalOut(dims, head)
-
-    def update_win(self, win_size=None):
-        if win_size is not None:
-            self.win_size = win_size
-            return win_size
-        elif hasattr(self, 'win_size') and self.win_size is not None:
-            win_size = self.win_size
-            return win_size
-        return None
-
-    def _focus(self, x, xa = None, mask = None, win_size=None):
-
-        q = self.que(self.lna(x))
-        k, v = self.kv(self.lna(x if xa is None else xa)).chunk(2, dim=-1)
-        q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
-
-        self.scale = q.shape[-1] ** -0.35
-        q = self.rope(q)
-        k = self.rope(k)
-
-        iteration = 0
-        temp = self.temp.item()
-        prev_out = torch.zeros_like(q)
-        attn_out = torch.zeros_like(q)
-        threshold = self.threshold
-        curq = q #if curq is None else curq
-
-        while iteration < self.max_iter:
-            eff_span = curq.shape[2]
-            if eff_span == 0:
-                break
-
-            qiter = curq[:, :, :eff_span, :]
-            kiter = k[:, :, :eff_span, :]
-            viter = v[:, :, :eff_span, :]
-            q = self.local.q_hd(qiter)
-            k = self.local.k_hd(kiter)
-            v = self.local.v_hd(viter)
-
-            iter_mask = None
-            if mask is not None:
-                if mask.dim() == 4:
-                    iter_mask = mask[:, :, :eff_span, :eff_span]
-                elif mask.dim() == 2:
-                    iter_mask = mask[:eff_span, :eff_span]
-
-            attn_iter = calculate_attention(
-                self.lnb(q), self.lnb(k), v,
-                mask=iter_mask, temp=temp)
-
-            iter_out = torch.zeros_like(curq)
-            iter_out[:, :, :eff_span, :] = attn_iter
-            diff = torch.abs(iter_out - prev_out).mean()
-
-            if diff < threshold and iteration > 0:
-                attn_out = iter_out
-                break
-
-            prev_out = iter_out.clone()
-            curq = curq + iter_out
-            attn_out = iter_out
-            iteration += 1
-            temp -= 0.005
-
-        return rearrange(attn_out, 'b h c d -> b c (h d)')
-
-    def _slide_win_local(self, x, mask = None) -> Tensor:
-
-        win = self.update_win()
-        win_size = win if win is not None else self.head_dim
-        span_len = win_size + win_size // self.head
-
-        _, ctx, _ = x.shape
-        out = torch.zeros_like(x)
-        windows = (ctx + win_size - 1) // win_size
-
-        for i in range(windows):
-            qstart = i * win_size
-            qend = min(qstart + win_size, ctx)
-            qlen = qend - qstart
-            if qlen == 0:
-                continue
-
-            kstart = max(0, qend - span_len)
-            qwin = x[:, qstart:qend, :]
-            kwin = x[:, kstart:qend, :]
-
-            win_mask = None
-            if mask is not None:
-                if mask.dim() == 4:
-                    win_mask = mask[:, :, qstart:qend, kstart:qend]
-                elif mask.dim() == 2:
-                    win_mask = mask[qstart:qend, kstart:qend]
-
-            attn_out = self._focus(x=qwin, xa=kwin, mask=win_mask, win_size=win_size)
-            out[:, qstart:qend, :] = attn_out
-        return out
-
     def forward(self, x, xa = None, mask = None):
-
-        xa = self._slide_win_local(xa, mask=None)
-        out = self._focus(x, xa, mask=None)
-        return self.out(out)
-
-def scaled_relu(x, sequence_length):
-    relu_output = torch.relu(x)
-    return relu_output / sequence_length
-
-def taylor_softmax(x, order=2):
-    taylor_approx = 1.0
-    for i in range(1, order + 1):
-        factorial_i = torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
-        taylor_approx += x**i / factorial_i
-    return taylor_approx / torch.sum(taylor_approx, dim=-1, keepdim=True)
-
-def taylor_softmax_2nd_order(x):
-    exp_approx = 1 + x + (x**2) / 2
-    return exp_approx / torch.sum(exp_approx, dim=-1, keepdim=True)
-
-def cos_sim(q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
-    q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
-    k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
-    qk_cosine = torch.matmul(q_norm, k_norm.transpose(-1, -2))
-    qk_cosine = qk_cosine + mask
-    weights = F.softmax(qk_cosine, dim=-1)
-    out = torch.matmul(weights, v)
-    return out
-
-class attentiona(nn.Module):
-    def __init__(self, dims: int, head: int, dropout_rate: float = 0.1):
-        super().__init__()
-
-        self.head = head
-        self.dims = dims
-        self.que = nn.Linear(dims, dims, bias=False)
-        self.kv = nn.Linear(dims, dims * 2, bias=False)
-        self.out = nn.Linear(dims, dims, bias=False)
-        self.ln = nn.LayerNorm(dims)
-        self.rope = rotary(dims, head)
-
-    def forward(self, x, xa = None, mask = None):
-
-        q = self.que(self.ln(x))
-        k, v = self.kv(self.ln(x if xa is None else xa)).chunk(2, dim=-1)
 
         q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
         scale = q.shape[-1] ** -0.5
 
@@ -291,59 +115,29 @@ class attentiona(nn.Module):
 
         qk = einsum('b h k d, b h q d -> b h k q', q, k) * scale
 
         if there_is_a(mask):
-
-
 
-        qk =
 
         wv = einsum('b h k q, b h q d -> b h k d', qk, v)
         wv = rearrange(wv, 'b h c d -> b c (h d)')
         out = self.out(wv)
         return out
 
-class attentiond(nn.Module):
-    def __init__(self, dims: int, head: int):
-        super().__init__()
-        self.head = head
-        self.dims = dims
-
-        self.que = nn.Linear(dims, dims, bias=False)
-        self.kv = nn.Linear(dims, dims * 2, bias=False)
-        self.out = nn.Linear(dims, dims, bias=False)
-
-        self.ln = nn.LayerNorm(dims)
-        self.rope = rotary(dims, head)
-
-        self.x = nn.Conv2d(head, head, 1, bias = False)
-        self.xa = nn.Conv2d(head, head, 1, bias = False)
-
-    def forward(self, x, xa = None, mask = None):
-
-        qk, v = self.kv(self.ln(x)).chunk(2, dim=-1)
-        qka, va = self.kv(self.ln(x if xa is None else xa)).chunk(2, dim=-1)
-        qk, qka, v, va = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.head), (qk, qka, v, va))
-        qk = einsum('b h q d, b h k d -> b h q k', qk, qka)
-
-        if there_is_a(mask):
-            mask = mask[:qk.shape[2], :qk.shape[2]]
-            qk = qk.masked_fill(mask.bool(), -torch.inf)
-
-        x = qk.softmax(dim = -1)
-        xa = qk.softmax(dim = -2)
-        x = self.x(x)
-        xa = self.xa(xa)
-        x = einsum('b h i j, b h j d -> b h i d', x, va)
-        xa = einsum('b h j i, b h j d -> b h i d', xa, v)
-        x, xa = map(lambda t: rearrange(t, 'b h n d -> b n (h d)'), (x, xa))
-        out = self.out(x)
-        return out
-
 class tgate(nn.Module):
     def __init__(self, dims, num_types=4):
         super().__init__()
         self.gates = nn.ModuleList([nn.Sequential(nn.Linear(dims, dims), nn.Sigmoid()) for _ in range(num_types)])
-        self.classifier = nn.Sequential(nn.Linear(dims, num_types),
     def forward(self, x):
         types = self.classifier(x)
         gates = torch.stack([gate(x) for gate in self.gates], dim=-1)
@@ -356,9 +150,7 @@ class residual(nn.Module):
 
         self.lna = nn.LayerNorm(dims, bias=False)
         self.atta = attentiona(dims, head)
-
-        self.attc = attentiond(dims, head)
-
         self.tgate = tgate(dims, num_types=1)
         self.mlp = nn.Sequential(nn.Linear(dims, dims*4), get_activation(act), nn.Linear(dims*4, dims))
 
@@ -371,18 +163,19 @@ class residual(nn.Module):
         x = out
         if xa is not None:
             x = x + self.atta(x, xa, mask=None)
         x = x + self.tgate(x)
         x = x + self.mlp(self.lna(x))
         return x
 
 class processor(nn.Module):
-    def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
-
         super(processor, self).__init__()
 
         self.ln = nn.LayerNorm(dims)
         self.token = nn.Embedding(vocab, dims)
-        self.audio = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
         self.positions = nn.Parameter(torch.empty(ctx, dims), requires_grad=True)
         self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
 
@@ -392,17 +185,12 @@ class processor(nn.Module):
             nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
             nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
 
-        self.
-
-
-        mask = torch.triu(torch.ones(ctx, ctx), diagonal=1)
-        mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)
 
-    def forward(self, x, xa,
-
-        if xa.dim() == 2:
-            xa = xa.unsqueeze(0)
 
         offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
         x = (self.token(x.long()) + self.positions[offset : offset + x.shape[-1]])
@@ -410,9 +198,9 @@ class processor(nn.Module):
         xa = self.encoder(xa).permute(0, 2, 1)
         xa = xa + self.audio(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)
 
-        for block in chain(self.
             xa = block(xa, mask=None)
-            x = block(x, mask=
             x = block(x, xa, mask=None)
             if blend:
                 if sequential:
@@ -421,8 +209,7 @@ class processor(nn.Module):
                     a = torch.sigmoid(self.blend)
                     x = a * x + (1 - a) * y
 
-
-            xm = block(torch.cat([x, xa], dim=1), torch.cat([x, xa], dim=1), mask=None) if modal else None
             x = block(xm[:, :x.shape[1]], xm[:, x.shape[1]:], mask=None) if modal else x
             if blend:
                 if sequential:
@@ -449,31 +236,11 @@ class Model(nn.Module):
             layer=param.layer,
             act=param.act)
 
-        self.best_loss = float('inf')
-        self.factor = nn.Parameter(torch.tensor(2), requires_grad=False)
-
-    def update(self, win_size):
-        for name, module in self.processor.named_modules():
-            if isinstance(module, (attentionb)):
-                module.update_win(win_size)
-
-    def adjust_window(self, loss, ctx):
-        self.win_size = ((ctx // self.param.head))
-        if loss < self.best_loss:
-            win_size = (self.win_size * self.factor)
-        else:
-            win_size = (self.win_size // self.factor).clamp(0, self.win_size - 1)
-        self.win_size = win_size
-        self.best_loss = loss
-        self.update(win_size)
-        return win_size
-
     def forward(self, labels=None, input_ids=None, pitch=None, pitch_tokens=None, spectrogram=None, waveform=None):
 
         x = input_ids
         xa = pitch
-
-
         enc = {}
         if spectrogram is not None:
             enc["spectrogram"] = spectrogram
@@ -482,11 +249,11 @@ class Model(nn.Module):
         if pitch is not None:
             enc["pitch"] = pitch
 
-        logits = self.processor(x, xa,
         loss = None
         if labels is not None:
             loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
-
         return {"logits": logits, "loss": loss}
 
     def _init_weights(self, module):
@@ -529,25 +296,6 @@ class Model(nn.Module):
         if count > 0:
             print(f"{module_type}: {count}")
 
-    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
-        cache = {**cache} if cache is not None else {}
-        hooks = []
-        def save_to_cache(module, _, output):
-            if module not in cache or output.shape[1] > self.param.ctx:
-                cache[module] = output
-            else:
-                cache[module] = torch.cat([cache[module], output], dim=1).detach()
-            return cache[module]
-
-        def install_hooks(layer: nn.Module):
-            if isinstance(layer, attentiona):
-                hooks.append(layer.k.register_forward_hook(save_to_cache))
-                hooks.append(layer.v.register_forward_hook(save_to_cache))
-        self.processor.apply(install_hooks)
-        return cache, hooks
-
-### "pipeline"
-
 def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=True, load_saved=False, save_dataset=True, cache_dir='E:/hf', extract_args=None, max_ctx=2048):
 
     if load_saved:
@@ -555,21 +303,26 @@ def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, st
         cache_dir = cache_dir
     else:
         cache_dir = cache_dir
     os.makedirs(cache_dir, exist_ok=True)
     cache_file_train = os.path.join(cache_dir, "train.arrow")
     cache_file_test = os.path.join(cache_dir, "test.arrow")
     if os.path.exists(cache_file_train) and os.path.exists(cache_file_test):
         from datasets import Dataset
         train_dataset = Dataset.load_from_disk(cache_file_train)
         test_dataset = Dataset.load_from_disk(cache_file_test)
         return train_dataset, test_dataset
     def filter_func(x):
         return (0 < len(x["transcription"]) < max_ctx and
                 len(x["audio"]["array"]) > 0 and
                 len(x["audio"]["array"]) < max_ctx * 160)
 
-    raw_train
-
 
     raw_train = raw_train.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
     raw_test = raw_test.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
@@ -586,9 +339,9 @@ def main():
     tokenizer = setup_tokenizer("D:/newmodel/mod5/tokenizer.json")
 
     extract_args = {
-        "waveform":
-        "spec":
-        "pitch_tokens":
         "pitch": True,
         "harmonics": False,
         "aperiodics": False,
@@ -616,11 +369,11 @@
         output_dir=log_dir,
         per_device_train_batch_size=1,
         per_device_eval_batch_size=1,
-        max_steps=
-        eval_steps=
-        save_steps=
-        warmup_steps=
-        logging_steps=
         logging_dir=log_dir,
         logging_strategy="steps",
         eval_strategy="steps",
@@ -632,14 +385,15 @@
         save_safetensors=False,
         eval_on_start=False,
         batch_eval_metrics=False,
-        disable_tqdm=False,
         include_tokens_per_second=True,
         include_num_input_tokens_seen=True,
         learning_rate=0.00025,
         weight_decay=0.025,
     )
 
-    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, eps=1e-
     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_args.max_steps, eta_min=1e-9, last_epoch=-1)
 
     trainer = Seq2SeqTrainer(
@@ -658,3 +412,4 @@ def main():
 
 if __name__ == "__main__":
     main()
After (lines added by this commit are prefixed with "+"):

 
 import warnings
 import logging
 from itertools import chain
 import torch
 from torch import nn, Tensor, einsum
 import numpy as np
 from dataclasses import dataclass
 from einops import rearrange
 from datetime import datetime
+from echoutils import *
 from transformers.trainer_seq2seq import Seq2SeqTrainer
 from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 dtype = torch.float32
 warnings.filterwarnings("ignore")
 logging.basicConfig(level=logging.ERROR)
 
+def sinusoids(ctx, dims, max_tscale=10000):
+    assert dims % 2 == 0
+    pos = torch.log(torch.tensor(float(max_tscale))) / (dims // 2 - 1)
+    tscales = torch.exp(-pos * torch.arange(dims // 2, device=device, dtype=torch.float32))
+    scaled = torch.arange(ctx, device=device, dtype=torch.float32).unsqueeze(1) * tscales.unsqueeze(0)
+    position = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=1)
+    positional_embedding = nn.Parameter(position, requires_grad=True)
+    return positional_embedding
+
+def get_activation(act: str) -> nn.Module:
+    act_map = {
+        "gelu": nn.GELU(),
+        "relu": nn.ReLU(),
+        "sigmoid": nn.Sigmoid(),
+        "tanh": nn.Tanh(),
+        "swish": nn.SiLU(),
+        "tanhshrink": nn.Tanhshrink(),
+        "softplus": nn.Softplus(),
+        "softshrink": nn.Softshrink(),
+        "leaky_relu": nn.LeakyReLU(),
+        "elu": nn.ELU()
+    }
+    return act_map.get(act, nn.GELU())
+
 def there_is_a(val):
     return val is not None
 
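The commit inlines sinusoids and get_activation, which the old file imported from echoutils. A minimal usage sketch of the two helpers as defined above (it assumes those definitions are in scope; shapes and the GELU fallback follow directly from the code):

    # assumes the sinusoids/get_activation definitions from the new model_simple.py are in scope
    pe = sinusoids(ctx=100, dims=64)      # nn.Parameter of shape (ctx, dims) = (100, 64)
    print(pe.shape)                       # torch.Size([100, 64])

    act = get_activation("swish")         # -> nn.SiLU()
    act = get_activation("not_a_name")    # unknown names fall back to nn.GELU()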
     layer: int
     act: str
 
 class rotary(nn.Module):
     def __init__(self, dims, head):
         super(rotary, self).__init__()
 
         self.head = head
         self.head_dim = dims // head
 
+        self.theta = nn.Parameter((torch.tensor(16000, device=device, dtype=dtype)), requires_grad=True)
         self.register_buffer('freqs_base', self._compute_freqs_base(), persistent=False)
 
     def _compute_freqs_base(self):
 
     def forward(self, x) -> Tensor:
         freqs = (self.theta / 220.0) * self.freqs_base
         pos = torch.arange(x.shape[2], device=device, dtype=dtype)
         freqs = pos[:, None] * freqs
+        freqs = torch.polar(torch.ones_like(freqs), freqs)
 
         x1 = x[..., :freqs.shape[-1]*2]
         x2 = x[..., freqs.shape[-1]*2:]
 
         x1 = x1.view(orig_shape)
         return torch.cat([x1.type_as(x), x2], dim=-1)
 
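The middle of rotary.forward is not shown in this diff; the visible x1/x2 split, the orig_shape round trip, and the torch.polar frequencies are consistent with the usual complex-multiplication rotary step. A sketch of that step under that assumption, with purely illustrative shapes and variable names (not taken from the file):

    import torch

    x = torch.randn(1, 4, 10, 64)                                  # (batch, head, ctx, head_dim), illustrative
    freqs = torch.polar(torch.ones(10, 16), torch.randn(10, 16))   # complex, (ctx, rot_dim // 2)

    x1 = x[..., :freqs.shape[-1] * 2]                              # slice that gets rotated
    x2 = x[..., freqs.shape[-1] * 2:]                              # remainder passes through unchanged
    orig_shape = x1.shape
    x1 = torch.view_as_complex(x1.float().reshape(*x1.shape[:-1], -1, 2))
    x1 = torch.view_as_real(x1 * freqs).flatten(-2)                # rotate pairs, back to real
    x1 = x1.view(orig_shape)
    out = torch.cat([x1.type_as(x), x2], dim=-1)                   # same shape as x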
+class attentiona(nn.Module):
     def __init__(self, dims: int, head: int):
         super().__init__()
 
         self.head = head
         self.dims = dims
         self.head_dim = dims // head
 
+        self.pad_token = 0
+        self.zmin = 1e-6
+        self.zmax = 1e-5
+        self.zero = nn.Parameter(torch.tensor(1e-4, device=device, dtype=dtype), requires_grad=False)
+
+        self.q = nn.Linear(dims, dims, bias=False)
         self.kv = nn.Linear(dims, dims * 2, bias=False)
         self.out = nn.Linear(dims, dims, bias=False)
 
         self.lna = nn.LayerNorm(dims)
         self.rope = rotary(dims, head)
 
     def forward(self, x, xa = None, mask = None):
+        zero = self.zero
 
+        q = self.q(self.lna(x))
+        k, v = self.kv(self.lna(x if xa is None else xa)).chunk(2, dim=-1)
         q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
         scale = q.shape[-1] ** -0.5
 
 
         qk = einsum('b h k d, b h q d -> b h k q', q, k) * scale
 
+        scale = torch.ones_like(k[:, :, :, 0])
+        zero = torch.clamp(F.softplus(zero), 1e-6, 1e-5)
+        scale[k[:, :, :, 0].float() == 0] = zero
+
         if there_is_a(mask):
+            i, j = qk.shape[-2:]
+            mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
+            qk = qk.masked_fill(mask, -torch.finfo(qk.dtype).max) * scale.unsqueeze(-2).expand(qk.shape)
+            qk = F.sigmoid(qk)
 
+        qk = qk * scale.unsqueeze(-2)
+        qk = taylor_softmax(qk, order=2)
 
         wv = einsum('b h k q, b h q d -> b h k d', qk, v)
         wv = rearrange(wv, 'b h c d -> b c (h d)')
         out = self.out(wv)
         return out
 
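The rewritten attentiona calls taylor_softmax, F.softplus and F.sigmoid, none of which are defined in this file after the change; presumably they now come in through the from echoutils import * line. For reference, the in-file definition of taylor_softmax that this commit removes (reproduced from the deleted block above) normalizes a truncated Taylor expansion of exp(x) along the last dimension:

    import torch

    def taylor_softmax(x, order=2):
        # 1 + x + x**2/2! + ... up to the given order, normalized along the last dim
        taylor_approx = 1.0
        for i in range(1, order + 1):
            factorial_i = torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
            taylor_approx += x**i / factorial_i
        return taylor_approx / torch.sum(taylor_approx, dim=-1, keepdim=True)

    w = taylor_softmax(torch.randn(2, 4, 8, 8), order=2)
    print(w.sum(-1))   # each row sums to 1; for order=2 the expansion is always positive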
 class tgate(nn.Module):
     def __init__(self, dims, num_types=4):
         super().__init__()
         self.gates = nn.ModuleList([nn.Sequential(nn.Linear(dims, dims), nn.Sigmoid()) for _ in range(num_types)])
+        self.classifier = nn.Sequential(nn.Linear(dims, num_types), nn.Softmax(dim=-1))
     def forward(self, x):
         types = self.classifier(x)
         gates = torch.stack([gate(x) for gate in self.gates], dim=-1)
 
 
         self.lna = nn.LayerNorm(dims, bias=False)
         self.atta = attentiona(dims, head)
+
         self.tgate = tgate(dims, num_types=1)
         self.mlp = nn.Sequential(nn.Linear(dims, dims*4), get_activation(act), nn.Linear(dims*4, dims))
 
 
         x = out
         if xa is not None:
             x = x + self.atta(x, xa, mask=None)
+
         x = x + self.tgate(x)
         x = x + self.mlp(self.lna(x))
         return x
 
 class processor(nn.Module):
+    def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu", modal=True):
         super(processor, self).__init__()
 
         self.ln = nn.LayerNorm(dims)
         self.token = nn.Embedding(vocab, dims)
+        self.audio = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
+
         self.positions = nn.Parameter(torch.empty(ctx, dims), requires_grad=True)
         self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
 
 
             nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
             nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
 
+        self.block = nn.ModuleList([residual(dims, head, act_fn) for _ in range(layer)])
+        mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)
 
+    def forward(self, x, xa, enc=None, sequential=False, modal=True, blend=False, kv_cache=None) -> Tensor:
+        mask = self.mask[:x.shape[1], :x.shape[1]]
 
         offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
         x = (self.token(x.long()) + self.positions[offset : offset + x.shape[-1]])
 
         xa = self.encoder(xa).permute(0, 2, 1)
         xa = xa + self.audio(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)
 
+        for block in chain(self.block or []):
             xa = block(xa, mask=None)
+            x = block(x, mask=mask)
             x = block(x, xa, mask=None)
             if blend:
                 if sequential:
 
                     a = torch.sigmoid(self.blend)
                     x = a * x + (1 - a) * y
 
+            xm = block(torch.cat([x, xa], dim=1), mask=mask) if modal else None
             x = block(xm[:, :x.shape[1]], xm[:, x.shape[1]:], mask=None) if modal else x
             if blend:
                 if sequential:
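The processor registers its causal mask once as an upper-triangular matrix of -inf and slices it to the decoded length in forward (mask = self.mask[:x.shape[1], :x.shape[1]]). A quick check of what that buffer looks like for a context of 4:

    import numpy as np
    import torch

    ctx = 4
    mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
    print(mask)
    # tensor([[0., -inf, -inf, -inf],
    #         [0., 0., -inf, -inf],
    #         [0., 0., 0., -inf],
    #         [0., 0., 0., 0.]])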
             layer=param.layer,
             act=param.act)
 
     def forward(self, labels=None, input_ids=None, pitch=None, pitch_tokens=None, spectrogram=None, waveform=None):
 
         x = input_ids
         xa = pitch
+
         enc = {}
         if spectrogram is not None:
             enc["spectrogram"] = spectrogram
 
         if pitch is not None:
             enc["pitch"] = pitch
 
+        logits = self.processor(x, xa, enc)
         loss = None
         if labels is not None:
             loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
+
         return {"logits": logits, "loss": loss}
 
     def _init_weights(self, module):
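Model.forward flattens logits and labels before the cross-entropy call and passes ignore_index=0, so positions labelled 0 (padding) do not contribute to the loss. A toy-shape check of that exact call:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 5, 40)        # (batch, ctx, vocab); sizes are illustrative
    labels = torch.randint(1, 40, (2, 5))
    labels[:, -2:] = 0                    # pretend the last two positions are padding
    loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
    print(loss.item())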
         if count > 0:
             print(f"{module_type}: {count}")
 
 def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=True, load_saved=False, save_dataset=True, cache_dir='E:/hf', extract_args=None, max_ctx=2048):
 
     if load_saved:
 
         cache_dir = cache_dir
     else:
         cache_dir = cache_dir
+
     os.makedirs(cache_dir, exist_ok=True)
     cache_file_train = os.path.join(cache_dir, "train.arrow")
     cache_file_test = os.path.join(cache_dir, "test.arrow")
+
     if os.path.exists(cache_file_train) and os.path.exists(cache_file_test):
         from datasets import Dataset
         train_dataset = Dataset.load_from_disk(cache_file_train)
         test_dataset = Dataset.load_from_disk(cache_file_test)
         return train_dataset, test_dataset
+
     def filter_func(x):
         return (0 < len(x["transcription"]) < max_ctx and
                 len(x["audio"]["array"]) > 0 and
                 len(x["audio"]["array"]) < max_ctx * 160)
 
+    raw_train = load_dataset(
+        "google/fleurs", "en_us", token=token, split="train", streaming=streaming).take(1000)
+    raw_test = load_dataset(
+        "google/fleurs", "en_us", token=token, split="test", streaming=streaming).take(100)
 
     raw_train = raw_train.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
     raw_test = raw_test.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
 
     tokenizer = setup_tokenizer("D:/newmodel/mod5/tokenizer.json")
 
     extract_args = {
+        "waveform": False,
+        "spec": False,
+        "pitch_tokens": False,
         "pitch": True,
         "harmonics": False,
         "aperiodics": False,
 
         output_dir=log_dir,
         per_device_train_batch_size=1,
         per_device_eval_batch_size=1,
+        max_steps=1000,
+        eval_steps=100,
+        save_steps=100,
+        warmup_steps=10,
+        logging_steps=10,
         logging_dir=log_dir,
         logging_strategy="steps",
         eval_strategy="steps",
 
         save_safetensors=False,
         eval_on_start=False,
         batch_eval_metrics=False,
+        disable_tqdm=False,
         include_tokens_per_second=True,
         include_num_input_tokens_seen=True,
         learning_rate=0.00025,
         weight_decay=0.025,
     )
 
+    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, eps=1e-10, weight_decay=training_args.weight_decay, betas=(0.9, 0.999), amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False)
+
     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_args.max_steps, eta_min=1e-9, last_epoch=-1)
 
     trainer = Seq2SeqTrainer(
 
 
 if __name__ == "__main__":
     main()
+
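For scale: filter_func keeps clips shorter than max_ctx * 160 audio samples, which at the 16 kHz default sample rate works out to roughly 20.5 seconds (160 samples per context frame, i.e. a 10 ms hop, appears to be the assumption):

    max_ctx, sample_rate = 2048, 16000
    max_samples = max_ctx * 160            # 327,680 samples
    print(max_samples / sample_rate)       # 20.48 seconds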