Sin2pi committed (verified) · Commit 79996fa · 1 Parent(s): bcea466

Update model_simple.py

Files changed (1):
  1. model_simple.py +82 -327

model_simple.py CHANGED
@@ -1,25 +1,45 @@
 
 import warnings
-import os
 import logging
 from itertools import chain
 import torch
 from torch import nn, Tensor, einsum
-from typing import Optional
 import numpy as np
 from dataclasses import dataclass
 from einops import rearrange
-from datasets import load_dataset, Audio
-from echoutils import extract_features, setup_tokenizer, compute_metrics, DataCollator, preprocess_logits_for_metrics, sinusoids, get_activation
 from datetime import datetime
+from echoutils import *
 from transformers.trainer_seq2seq import Seq2SeqTrainer
 from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
-
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 dtype = torch.float32
 warnings.filterwarnings("ignore")
 logging.basicConfig(level=logging.ERROR)
 
+def sinusoids(ctx, dims, max_tscale=10000):
+    assert dims % 2 == 0
+    pos = torch.log(torch.tensor(float(max_tscale))) / (dims // 2 - 1)
+    tscales = torch.exp(-pos * torch.arange(dims // 2, device=device, dtype=torch.float32))
+    scaled = torch.arange(ctx, device=device, dtype=torch.float32).unsqueeze(1) * tscales.unsqueeze(0)
+    position = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=1)
+    positional_embedding = nn.Parameter(position, requires_grad=True)
+    return positional_embedding
+
+def get_activation(act: str) -> nn.Module:
+    act_map = {
+        "gelu": nn.GELU(),
+        "relu": nn.ReLU(),
+        "sigmoid": nn.Sigmoid(),
+        "tanh": nn.Tanh(),
+        "swish": nn.SiLU(),
+        "tanhshrink": nn.Tanhshrink(),
+        "softplus": nn.Softplus(),
+        "softshrink": nn.Softshrink(),
+        "leaky_relu": nn.LeakyReLU(),
+        "elu": nn.ELU()
+    }
+    return act_map.get(act, nn.GELU())
+
 def there_is_a(val):
     return val is not None
 
@@ -33,29 +53,6 @@ class Dimensions:
     layer: int
     act: str
 
-def qkv_init(dims, head):
-    head_dim = dims // head
-    q = nn.Linear(dims, dims)
-    k = nn.Linear(dims, dims)
-    v = nn.Linear(dims, dims)
-    o = nn.Linear(dims, dims)
-    lna = nn.LayerNorm(dims)
-    lnb = nn.LayerNorm(dims)
-    lnc = nn.LayerNorm(head_dim)
-    lnd = nn.LayerNorm(head_dim)
-    return q, k, v, o, lna, lnb, lnc, lnd
-
-def shape(dims, head, q, k, v):
-    batch_size = q.shape[0]
-    seq_len_q = q.shape[1]
-    seq_len_kv = k.shape[1]
-    head_dim = dims // head
-
-    q = q.view(batch_size, seq_len_q, head, head_dim).transpose(1, 2)
-    k = k.view(batch_size, seq_len_kv, head, head_dim).transpose(1, 2)
-    v = v.view(batch_size, seq_len_kv, head, head_dim).transpose(1, 2)
-    return q, k, v
-
 class rotary(nn.Module):
     def __init__(self, dims, head):
         super(rotary, self).__init__()
@@ -63,7 +60,7 @@ class rotary(nn.Module):
         self.head = head
         self.head_dim = dims // head
 
-        self.theta = nn.Parameter((torch.tensor(10000, device=device, dtype=dtype)), requires_grad=True)
+        self.theta = nn.Parameter((torch.tensor(16000, device=device, dtype=dtype)), requires_grad=True)
         self.register_buffer('freqs_base', self._compute_freqs_base(), persistent=False)
 
     def _compute_freqs_base(self):
@@ -72,10 +69,9 @@
 
     def forward(self, x) -> Tensor:
         freqs = (self.theta / 220.0) * self.freqs_base
-
         pos = torch.arange(x.shape[2], device=device, dtype=dtype)
        freqs = pos[:, None] * freqs
-        freqs=torch.polar(torch.ones_like(freqs), freqs)
+        freqs = torch.polar(torch.ones_like(freqs), freqs)
 
         x1 = x[..., :freqs.shape[-1]*2]
         x2 = x[..., freqs.shape[-1]*2:]
@@ -86,203 +82,31 @@ class rotary(nn.Module):
         x1 = x1.view(orig_shape)
         return torch.cat([x1.type_as(x), x2], dim=-1)
 
-def calculate_attention(q, k, v, mask=None, temp=1.0, pytorch=True):
-    scaled_q = q
-    if temp != 1.0 and temp > 0:
-        scaled_q = q * (1.0 / temp)**.5
-    if pytorch:
-        out = torch.nn.functional.scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1)
-    else:
-        scale = q.shape[-1] ** -0.35
-        qk = (q * scale) @ (k * scale).transpose(-1, -2)
-        if there_is_a(mask):
-            mask = mask[:qk.shape[2], :qk.shape[2]]
-            qk = qk.masked_fill(mask.bool(), -torch.inf)
-        qk = qk.float()
-        w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype)
-        out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
-        qk = qk.detach()
-    return out
-
-class LocalOut(nn.Module):
+class attentiona(nn.Module):
     def __init__(self, dims: int, head: int):
         super().__init__()
-        self.head_dim = dims // head
-        self.dims = dims
-        self.q_hd = nn.Linear(self.head_dim, self.head_dim)
-        self.k_hd = nn.Linear(self.head_dim, self.head_dim)
-        self.v_hd = nn.Linear(self.head_dim, self.head_dim)
-        self.out = nn.Linear(self.head_dim, self.head_dim)
-
-    def _reshape_to_output(self, attn_output: Tensor) -> Tensor:
-        batch, _, ctx, _ = attn_output.shape
-        return attn_output.transpose(1, 2).contiguous().view(batch, ctx, self.dims)
-
-class attentionb(nn.Module):
-    def __init__(self, dims: int, head: int, max_iter: int = 3, threshold: float = 0.5, temp = 1.0):
-        super(attentionb, self).__init__()
 
         self.head = head
         self.dims = dims
         self.head_dim = dims // head
 
-        self.que = nn.Linear(dims, dims, bias=False)
+        self.pad_token = 0
+        self.zmin = 1e-6
+        self.zmax = 1e-5
+        self.zero = nn.Parameter(torch.tensor(1e-4, device=device, dtype=dtype), requires_grad=False)
+
+        self.q = nn.Linear(dims, dims, bias=False)
         self.kv = nn.Linear(dims, dims * 2, bias=False)
         self.out = nn.Linear(dims, dims, bias=False)
 
         self.lna = nn.LayerNorm(dims)
-        self.lnb = nn.LayerNorm(dims // head)
         self.rope = rotary(dims, head)
 
-        self.max_iter = max_iter
-        self.threshold = nn.Parameter(torch.tensor(threshold), requires_grad=True)
-        self.temp = nn.Parameter(torch.tensor(temp), requires_grad=True)
-        self.local = LocalOut(dims, head)
-
-    def update_win(self, win_size=None):
-        if win_size is not None:
-            self.win_size = win_size
-            return win_size
-        elif hasattr(self, 'win_size') and self.win_size is not None:
-            win_size = self.win_size
-            return win_size
-        return None
-
-    def _focus(self, x, xa = None, mask = None, win_size=None):
-
-        q = self.que(self.lna(x))
-        k, v = self.kv(self.lna(x if xa is None else xa)).chunk(2, dim=-1)
-        q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
-
-        self.scale = q.shape[-1] ** -0.35
-        q = self.rope(q)
-        k = self.rope(k)
-
-        iteration = 0
-        temp = self.temp.item()
-        prev_out = torch.zeros_like(q)
-        attn_out = torch.zeros_like(q)
-        threshold = self.threshold
-        curq = q #if curq is None else curq
-
-        while iteration < self.max_iter:
-            eff_span = curq.shape[2]
-            if eff_span == 0:
-                break
-
-            qiter = curq[:, :, :eff_span, :]
-            kiter = k[:, :, :eff_span, :]
-            viter = v[:, :, :eff_span, :]
-            q = self.local.q_hd(qiter)
-            k = self.local.k_hd(kiter)
-            v = self.local.v_hd(viter)
-
-            iter_mask = None
-            if mask is not None:
-                if mask.dim() == 4:
-                    iter_mask = mask[:, :, :eff_span, :eff_span]
-                elif mask.dim() == 2:
-                    iter_mask = mask[:eff_span, :eff_span]
-
-            attn_iter = calculate_attention(
-                self.lnb(q), self.lnb(k), v,
-                mask=iter_mask, temp=temp)
-
-            iter_out = torch.zeros_like(curq)
-            iter_out[:, :, :eff_span, :] = attn_iter
-            diff = torch.abs(iter_out - prev_out).mean()
-
-            if diff < threshold and iteration > 0:
-                attn_out = iter_out
-                break
-
-            prev_out = iter_out.clone()
-            curq = curq + iter_out
-            attn_out = iter_out
-            iteration += 1
-            temp -= 0.005
-
-        return rearrange(attn_out, 'b h c d -> b c (h d)')
-
-    def _slide_win_local(self, x, mask = None) -> Tensor:
-
-        win = self.update_win()
-        win_size = win if win is not None else self.head_dim
-        span_len = win_size + win_size // self.head
-
-        _, ctx, _ = x.shape
-        out = torch.zeros_like(x)
-        windows = (ctx + win_size - 1) // win_size
-
-        for i in range(windows):
-            qstart = i * win_size
-            qend = min(qstart + win_size, ctx)
-            qlen = qend - qstart
-            if qlen == 0:
-                continue
-
-            kstart = max(0, qend - span_len)
-            qwin = x[:, qstart:qend, :]
-            kwin = x[:, kstart:qend, :]
-
-            win_mask = None
-            if mask is not None:
-                if mask.dim() == 4:
-                    win_mask = mask[:, :, qstart:qend, kstart:qend]
-                elif mask.dim() == 2:
-                    win_mask = mask[qstart:qend, kstart:qend]
-
-            attn_out = self._focus(x=qwin, xa=kwin, mask=win_mask, win_size=win_size)
-            out[:, qstart:qend, :] = attn_out
-        return out
-
     def forward(self, x, xa = None, mask = None):
-        x = self._slide_win_local(x, mask=None)
-        xa = self._slide_win_local(xa, mask=None)
-        out = self._focus(x, xa, mask=None)
-        return self.out(out)
-
-def scaled_relu(x, sequence_length):
-    relu_output = torch.relu(x)
-    return relu_output / sequence_length
-
-def taylor_softmax(x, order=2):
-    taylor_approx = 1.0
-    for i in range(1, order + 1):
-        factorial_i = torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
-        taylor_approx += x**i / factorial_i
-    return taylor_approx / torch.sum(taylor_approx, dim=-1, keepdim=True)
-
-def taylor_softmax_2nd_order(x):
-    exp_approx = 1 + x + (x**2) / 2
-    return exp_approx / torch.sum(exp_approx, dim=-1, keepdim=True)
-
-def cos_sim(q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
-    q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
-    k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
-    qk_cosine = torch.matmul(q_norm, k_norm.transpose(-1, -2))
-    qk_cosine = qk_cosine + mask
-    weights = F.softmax(qk_cosine, dim=-1)
-    out = torch.matmul(weights, v)
-    return out
-
-class attentiona(nn.Module):
-    def __init__(self, dims: int, head: int, dropout_rate: float = 0.1):
-        super().__init__()
-
-        self.head = head
-        self.dims = dims
-        self.que = nn.Linear(dims, dims, bias=False)
-        self.kv = nn.Linear(dims, dims * 2, bias=False)
-        self.out = nn.Linear(dims, dims, bias=False)
-        self.ln = nn.LayerNorm(dims)
-        self.rope = rotary(dims, head)
-
-    def forward(self, x, xa = None, mask = None):
-
-        q = self.que(self.ln(x))
-        k, v = self.kv(self.ln(x if xa is None else xa)).chunk(2, dim=-1)
+        zero = self.zero
 
+        q = self.q(self.lna(x))
+        k, v = self.kv(self.lna(x if xa is None else xa)).chunk(2, dim=-1)
         q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
         scale = q.shape[-1] ** -0.5
 
@@ -291,59 +115,29 @@ class attentiona(nn.Module):
 
         qk = einsum('b h k d, b h q d -> b h k q', q, k) * scale
 
+        scale = torch.ones_like(k[:, :, :, 0])
+        zero = torch.clamp(F.softplus(zero), 1e-6, 1e-5)
+        scale[k[:, :, :, 0].float() == 0] = zero
+
         if there_is_a(mask):
-            mask = mask[:qk.shape[2], :qk.shape[2]]
-            qk = qk.masked_fill(mask.bool(), -torch.inf)
+            i, j = qk.shape[-2:]
+            mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
+            qk = qk.masked_fill(mask, -torch.finfo(qk.dtype).max) * scale.unsqueeze(-2).expand(qk.shape)
+            qk = F.sigmoid(qk)
 
-        qk = taylor_softmax(qk, order=2) # qk = torch.nn.functional.softmax(qk, dim=-1)
+        qk = qk * scale.unsqueeze(-2)
+        qk = taylor_softmax(qk, order=2)
 
         wv = einsum('b h k q, b h q d -> b h k d', qk, v)
         wv = rearrange(wv, 'b h c d -> b c (h d)')
         out = self.out(wv)
         return out
 
-class attentiond(nn.Module):
-    def __init__(self, dims: int, head: int):
-        super().__init__()
-        self.head = head
-        self.dims = dims
-
-        self.que = nn.Linear(dims, dims, bias=False)
-        self.kv = nn.Linear(dims, dims * 2, bias=False)
-        self.out = nn.Linear(dims, dims, bias=False)
-
-        self.ln = nn.LayerNorm(dims)
-        self.rope = rotary(dims, head)
-
-        self.x = nn.Conv2d(head, head, 1, bias = False)
-        self.xa = nn.Conv2d(head, head, 1, bias = False)
-
-    def forward(self, x, xa = None, mask = None):
-
-        qk, v = self.kv(self.ln(x)).chunk(2, dim=-1)
-        qka, va = self.kv(self.ln(x if xa is None else xa)).chunk(2, dim=-1)
-        qk, qka, v, va = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.head), (qk, qka, v, va))
-        qk = einsum('b h q d, b h k d -> b h q k', qk, qka)
-
-        if there_is_a(mask):
-            mask = mask[:qk.shape[2], :qk.shape[2]]
-            qk = qk.masked_fill(mask.bool(), -torch.inf)
-
-        x = qk.softmax(dim = -1)
-        xa = qk.softmax(dim = -2)
-        x = self.x(x)
-        xa = self.xa(xa)
-        x = einsum('b h i j, b h j d -> b h i d', x, va)
-        xa = einsum('b h j i, b h j d -> b h i d', xa, v)
-        x, xa = map(lambda t: rearrange(t, 'b h n d -> b n (h d)'), (x, xa))
-        out = self.out(x)
-        return out
-
 class tgate(nn.Module):
     def __init__(self, dims, num_types=4):
         super().__init__()
         self.gates = nn.ModuleList([nn.Sequential(nn.Linear(dims, dims), nn.Sigmoid()) for _ in range(num_types)])
-        self.classifier = nn.Sequential(nn.Linear(dims, num_types), torch.nn.functional.Softmax(dim=-1))
+        self.classifier = nn.Sequential(nn.Linear(dims, num_types), nn.Softmax(dim=-1))
     def forward(self, x):
         types = self.classifier(x)
         gates = torch.stack([gate(x) for gate in self.gates], dim=-1)
@@ -356,9 +150,7 @@ class residual(nn.Module):
 
         self.lna = nn.LayerNorm(dims, bias=False)
         self.atta = attentiona(dims, head)
-        self.attb = attentionb(dims, head, max_iter=1)
-        self.attc = attentiond(dims, head)
-
+
         self.tgate = tgate(dims, num_types=1)
         self.mlp = nn.Sequential(nn.Linear(dims, dims*4), get_activation(act), nn.Linear(dims*4, dims))
 
@@ -371,18 +163,19 @@
             x = out
         if xa is not None:
             x = x + self.atta(x, xa, mask=None)
+
         x = x + self.tgate(x)
         x = x + self.mlp(self.lna(x))
         return x
 
 class processor(nn.Module):
-    def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
-
+    def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu", modal=True):
         super(processor, self).__init__()
 
         self.ln = nn.LayerNorm(dims)
         self.token = nn.Embedding(vocab, dims)
-        self.audio = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
+        self.audio = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
+
         self.positions = nn.Parameter(torch.empty(ctx, dims), requires_grad=True)
         self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
 
@@ -392,17 +185,12 @@ class processor(nn.Module):
             nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
             nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
 
-        self.blocka = nn.ModuleList([residual(dims, head, act_fn) for _ in range(layer)])
-        self.blockm = nn.ModuleList([residual(dims, head, act_fn) for _ in range(2)])
-
-        mask = torch.triu(torch.ones(ctx, ctx), diagonal=1)
-        mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
+        self.block = nn.ModuleList([residual(dims, head, act_fn) for _ in range(layer)])
+        mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)
 
-    def forward(self, x, xa, xb, sequential=False, modal=False, blend=False, kv_cache=None) -> Tensor:
-
-        if xa.dim() == 2:
-            xa = xa.unsqueeze(0)
+    def forward(self, x, xa, enc=None, sequential=False, modal=True, blend=False, kv_cache=None) -> Tensor:
+        mask = self.mask[:x.shape[1], :x.shape[1]]
 
         offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
         x = (self.token(x.long()) + self.positions[offset : offset + x.shape[-1]])
@@ -410,9 +198,9 @@
         xa = self.encoder(xa).permute(0, 2, 1)
         xa = xa + self.audio(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)
 
-        for block in chain(self.blocka or []):
+        for block in chain(self.block or []):
             xa = block(xa, mask=None)
-            x = block(x, mask=self.mask)
+            x = block(x, mask=mask)
             x = block(x, xa, mask=None)
             if blend:
                 if sequential:
@@ -421,8 +209,7 @@
                     a = torch.sigmoid(self.blend)
                     x = a * x + (1 - a) * y
 
-        for block in chain(self.blockm or []):
-            xm = block(torch.cat([x, xa], dim=1), torch.cat([x, xa], dim=1), mask=None) if modal else None
+            xm = block(torch.cat([x, xa], dim=1), mask=mask) if modal else None
             x = block(xm[:, :x.shape[1]], xm[:, x.shape[1]:], mask=None) if modal else x
             if blend:
                 if sequential:
@@ -449,31 +236,11 @@ class Model(nn.Module):
             layer=param.layer,
             act=param.act)
 
-        self.best_loss = float('inf')
-        self.factor = nn.Parameter(torch.tensor(2), requires_grad=False)
-
-    def update(self, win_size):
-        for name, module in self.processor.named_modules():
-            if isinstance(module, (attentionb)):
-                module.update_win(win_size)
-
-    def adjust_window(self, loss, ctx):
-        self.win_size = ((ctx // self.param.head))
-        if loss < self.best_loss:
-            win_size = (self.win_size * self.factor)
-        else:
-            win_size = (self.win_size // self.factor).clamp(0, self.win_size - 1)
-        self.win_size = win_size
-        self.best_loss = loss
-        self.update(win_size)
-        return win_size
-
     def forward(self, labels=None, input_ids=None, pitch=None, pitch_tokens=None, spectrogram=None, waveform=None):
 
         x = input_ids
         xa = pitch
-        xb = spectrogram
-
+
         enc = {}
         if spectrogram is not None:
             enc["spectrogram"] = spectrogram
@@ -482,11 +249,11 @@
         if pitch is not None:
             enc["pitch"] = pitch
 
-        logits = self.processor(x, xa, xb)
+        logits = self.processor(x, xa, enc)
         loss = None
         if labels is not None:
             loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
-        self.adjust_window(loss=loss.item(), ctx=xa.shape[1])
+
         return {"logits": logits, "loss": loss}
 
     def _init_weights(self, module):
@@ -529,25 +296,6 @@
         if count > 0:
             print(f"{module_type}: {count}")
 
-    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
-        cache = {**cache} if cache is not None else {}
-        hooks = []
-        def save_to_cache(module, _, output):
-            if module not in cache or output.shape[1] > self.param.ctx:
-                cache[module] = output
-            else:
-                cache[module] = torch.cat([cache[module], output], dim=1).detach()
-            return cache[module]
-
-        def install_hooks(layer: nn.Module):
-            if isinstance(layer, attentiona):
-                hooks.append(layer.k.register_forward_hook(save_to_cache))
-                hooks.append(layer.v.register_forward_hook(save_to_cache))
-        self.processor.apply(install_hooks)
-        return cache, hooks
-
-### "pipeline"
-
 def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=True, load_saved=False, save_dataset=True, cache_dir='E:/hf', extract_args=None, max_ctx=2048):
 
     if load_saved:
@@ -555,21 +303,26 @@ def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, st
         cache_dir = cache_dir
     else:
         cache_dir = cache_dir
+
     os.makedirs(cache_dir, exist_ok=True)
     cache_file_train = os.path.join(cache_dir, "train.arrow")
     cache_file_test = os.path.join(cache_dir, "test.arrow")
+
     if os.path.exists(cache_file_train) and os.path.exists(cache_file_test):
         from datasets import Dataset
         train_dataset = Dataset.load_from_disk(cache_file_train)
         test_dataset = Dataset.load_from_disk(cache_file_test)
         return train_dataset, test_dataset
+
     def filter_func(x):
         return (0 < len(x["transcription"]) < max_ctx and
                 len(x["audio"]["array"]) > 0 and
                 len(x["audio"]["array"]) < max_ctx * 160)
 
-    raw_train = load_dataset("mozilla-foundation/common_voice_17_0", "en", token=token, split="train", trust_remote_code=True, streaming=True).rename_column("sentence", "transcription")
-    raw_test = load_dataset("mozilla-foundation/common_voice_17_0", "en", token=token, split="test", trust_remote_code=True, streaming=True).rename_column("sentence", "transcription").take(1000)
+    raw_train = load_dataset(
+        "google/fleurs", "en_us", token=token, split="train", streaming=streaming).take(1000)
+    raw_test = load_dataset(
+        "google/fleurs", "en_us", token=token, split="test", streaming=streaming).take(100)
 
     raw_train = raw_train.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
     raw_test = raw_test.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
@@ -586,9 +339,9 @@ def main():
    tokenizer = setup_tokenizer("D:/newmodel/mod5/tokenizer.json")
 
    extract_args = {
-        "waveform": True,
-        "spec": True,
-        "pitch_tokens": True,
+        "waveform": False,
+        "spec": False,
+        "pitch_tokens": False,
         "pitch": True,
         "harmonics": False,
         "aperiodics": False,
@@ -616,11 +369,11 @@
        output_dir=log_dir,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
-        max_steps=100000,
-        eval_steps=1000,
-        save_steps=1000,
-        warmup_steps=1000,
-        logging_steps=100,
+        max_steps=1000,
+        eval_steps=100,
+        save_steps=100,
+        warmup_steps=10,
+        logging_steps=10,
        logging_dir=log_dir,
        logging_strategy="steps",
        eval_strategy="steps",
@@ -632,14 +385,15 @@
        save_safetensors=False,
        eval_on_start=False,
        batch_eval_metrics=False,
-        disable_tqdm=False,
+        disable_tqdm=False,
        include_tokens_per_second=True,
        include_num_input_tokens_seen=True,
        learning_rate=0.00025,
        weight_decay=0.025,
    )
 
-    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, eps=1e-8, weight_decay=training_args.weight_decay, betas=(0.9, 0.999), amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, eps=1e-10, weight_decay=training_args.weight_decay, betas=(0.9, 0.999), amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False)
+
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_args.max_steps, eta_min=1e-9, last_epoch=-1)
 
    trainer = Seq2SeqTrainer(
@@ -658,3 +412,4 @@
 
 if __name__ == "__main__":
    main()
+
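A standalone way to sanity-check the reworked attention path in this commit: the sketch below replays the new key-padding scale, causal mask, and taylor-softmax steps from attentiona.forward on random tensors. It is a minimal illustration, not part of the committed file; taylor_softmax is reproduced from the helper deleted in this diff (the committed file pulls it in via `from echoutils import *`), F is torch.nn.functional, and the shapes are made up.

import torch
import torch.nn.functional as F
from torch import einsum

def taylor_softmax(x, order=2):
    # Second-order Taylor expansion of exp(x), normalized over the last dim
    # (copied from the helper removed in this commit).
    approx = 1.0
    for i in range(1, order + 1):
        factorial_i = torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
        approx = approx + x**i / factorial_i
    return approx / torch.sum(approx, dim=-1, keepdim=True)

b, h, ctx, d = 2, 4, 8, 16
q, k, v = (torch.randn(b, h, ctx, d) for _ in range(3))
k[0, :, -2:, :] = 0.0  # pretend the last two key positions are padding

qk = einsum('b h k d, b h q d -> b h k q', q, k) * d ** -0.5

# Keys whose leading component is exactly zero get a tiny scale, as in attentiona.
scale = torch.ones_like(k[:, :, :, 0])
zero = torch.clamp(F.softplus(torch.tensor(1e-4)), 1e-6, 1e-5)
scale[k[:, :, :, 0] == 0] = zero

# Causal mask, sigmoid squashing, then the Taylor-series stand-in for softmax.
i, j = qk.shape[-2:]
mask = torch.ones(i, j, dtype=torch.bool).triu(j - i + 1)
qk = qk.masked_fill(mask, -torch.finfo(qk.dtype).max) * scale.unsqueeze(-2)
qk = torch.sigmoid(qk)
qk = qk * scale.unsqueeze(-2)
w = taylor_softmax(qk, order=2)

out = einsum('b h k q, b h q d -> b h k d', w, v)
print(out.shape)  # torch.Size([2, 4, 8, 16])

The sigmoid between masking and normalization is what keeps the expansion stable: masked logits are filled with -finfo.max, and squashing them to zero before the x**2 term avoids float32 overflow inside taylor_softmax.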