Sin2pi committed on
Commit
d7c7a93
·
verified ·
1 Parent(s): 203f8e5

Update model_simple.py

Browse files
Files changed (1) hide show
  1. model_simple.py +222 -91
model_simple.py CHANGED
@@ -1,5 +1,6 @@
1
 
2
  import warnings
 
3
  import logging
4
  from itertools import chain
5
  import torch
@@ -7,38 +8,18 @@ from torch import nn, Tensor, einsum
7
  from typing import Optional
8
  import numpy as np
9
  from dataclasses import dataclass
10
- from torch.nn.functional import scaled_dot_product_attention
11
  from einops import rearrange
 
 
 
 
 
12
 
13
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14
  dtype = torch.float32
15
  warnings.filterwarnings("ignore")
16
  logging.basicConfig(level=logging.ERROR)
17
 
18
- def sinusoids(ctx, dims, max_tscale=10000):
19
- assert dims % 2 == 0
20
- pos = torch.log(torch.tensor(float(max_tscale))) / (dims // 2 - 1)
21
- tscales = torch.exp(-pos * torch.arange(dims // 2, device=device, dtype=torch.float32))
22
- scaled = torch.arange(ctx, device=device, dtype=torch.float32).unsqueeze(1) * tscales.unsqueeze(0)
23
- position = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=1)
24
- positional_embedding = nn.Parameter(position, requires_grad=True)
25
- return positional_embedding
26
-
27
- def get_activation(act: str) -> nn.Module:
28
- act_map = {
29
- "gelu": nn.GELU(),
30
- "relu": nn.ReLU(),
31
- "sigmoid": nn.Sigmoid(),
32
- "tanh": nn.Tanh(),
33
- "swish": nn.SiLU(),
34
- "tanhshrink": nn.Tanhshrink(),
35
- "softplus": nn.Softplus(),
36
- "softshrink": nn.Softshrink(),
37
- "leaky_relu": nn.LeakyReLU(),
38
- "elu": nn.ELU()
39
- }
40
- return act_map.get(act, nn.GELU())
41
-
42
  def there_is_a(val):
43
  return val is not None
44
 
@@ -105,12 +86,22 @@ class rotary(nn.Module):
105
  x1 = x1.view(orig_shape)
106
  return torch.cat([x1.type_as(x), x2], dim=-1)
107
 
108
- def calculate_attention(q, k, v, mask=None, temp=1.0):
109
  scaled_q = q
110
  if temp != 1.0 and temp > 0:
111
  scaled_q = q * (1.0 / temp)**.5
112
- print(temp)
113
- out = scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1)
 
 
 
 
 
 
 
 
 
 
114
  return out
115
 
116
  class LocalOut(nn.Module):
@@ -128,27 +119,24 @@ class LocalOut(nn.Module):
128
  return attn_output.transpose(1, 2).contiguous().view(batch, ctx, self.dims)
129
 
130
  class attentionb(nn.Module):
131
- def __init__(self, dims: int, head: int, max_iter: int = 3,
132
- threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1, temp = 1.0):
133
  super(attentionb, self).__init__()
134
 
135
  self.head = head
136
  self.dims = dims
137
  self.head_dim = dims // head
138
- self.win = 0
139
 
140
  self.que = nn.Linear(dims, dims, bias=False)
141
  self.kv = nn.Linear(dims, dims * 2, bias=False)
142
  self.out = nn.Linear(dims, dims, bias=False)
143
 
144
  self.lna = nn.LayerNorm(dims)
145
- self.lnb = nn.LayerNorm(self.head_dim)
146
  self.rope = rotary(dims, head)
147
 
148
  self.max_iter = max_iter
149
  self.threshold = nn.Parameter(torch.tensor(threshold), requires_grad=True)
150
  self.temp = nn.Parameter(torch.tensor(temp), requires_grad=True)
151
- self.factor = nn.Parameter(torch.tensor(factor), requires_grad=True)
152
  self.local = LocalOut(dims, head)
153
 
154
  def update_win(self, win_size=None):
@@ -163,7 +151,7 @@ class attentionb(nn.Module):
163
  def _focus(self, x, xa = None, mask = None, win_size=None):
164
 
165
  q = self.que(self.lna(x))
166
- k, v = self.kv(self.lna(x if not xa else xa)).chunk(2, dim=-1)
167
  q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
168
 
169
  self.scale = q.shape[-1] ** -0.35
@@ -171,17 +159,14 @@ class attentionb(nn.Module):
171
  k = self.rope(k)
172
 
173
  iteration = 0
174
- temp = self.temp
175
  prev_out = torch.zeros_like(q)
176
  attn_out = torch.zeros_like(q)
177
  threshold = self.threshold
178
- factor = self.factor
179
  curq = q #if curq is None else curq
180
 
181
  while iteration < self.max_iter:
182
- eff_span = min(curq.shape[1], k.shape[1])
183
- if xa is not None:
184
- eff_span = min(eff_span, xa.shape[1])
185
  if eff_span == 0:
186
  break
187
 
@@ -206,24 +191,18 @@ class attentionb(nn.Module):
206
  iter_out = torch.zeros_like(curq)
207
  iter_out[:, :, :eff_span, :] = attn_iter
208
  diff = torch.abs(iter_out - prev_out).mean()
209
- dthresh = threshold + factor * diff
210
- if diff < dthresh and iteration > 0:
211
  attn_out = iter_out
212
  break
213
 
214
  prev_out = iter_out.clone()
215
  curq = curq + iter_out
216
  attn_out = iter_out
217
- # if win_size is not None:
218
- # if win_size > self.win:
219
- # temp += 0.005
220
- # else:
221
- # temp -= 0.005
222
- # self.win = win_size
223
  iteration += 1
 
224
 
225
- out = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2)
226
- return out
227
 
228
  def _slide_win_local(self, x, mask = None) -> Tensor:
229
 
@@ -260,26 +239,45 @@ class attentionb(nn.Module):
260
  def forward(self, x, xa = None, mask = None):
261
  x = self._slide_win_local(x, mask=None)
262
  xa = self._slide_win_local(xa, mask=None)
263
- output = self._focus(x, xa, mask=None)
264
- return self.out(output)
265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  class attentiona(nn.Module):
267
- def __init__(self, dims: int, head: int, dropout_rate: float = 0.1, cross_talk=False):
268
  super().__init__()
 
269
  self.head = head
270
  self.dims = dims
271
- self.cross_talk = cross_talk
272
-
273
  self.que = nn.Linear(dims, dims, bias=False)
274
  self.kv = nn.Linear(dims, dims * 2, bias=False)
275
  self.out = nn.Linear(dims, dims, bias=False)
276
-
277
  self.ln = nn.LayerNorm(dims)
278
  self.rope = rotary(dims, head)
279
 
280
- self.x = nn.Conv2d(head, head, 1, bias = False) if cross_talk else None
281
- self.xa = nn.Conv2d(head, head, 1, bias = False) if cross_talk else None
282
-
283
  def forward(self, x, xa = None, mask = None):
284
 
285
  q = self.que(self.ln(x))
@@ -292,16 +290,13 @@ class attentiona(nn.Module):
292
  k = self.rope(k)
293
 
294
  qk = einsum('b h k d, b h q d -> b h k q', q, k) * scale
 
 
295
 
296
- if there_is_a(mask):
297
- i, j = qk.shape[-2:]
298
- mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
299
- qk = qk.masked_fill(mask, -torch.finfo(qk.dtype).max)
300
-
301
- qk = torch.nn.functional.softmax(qk, dim=-1)
302
  wv = einsum('b h k q, b h q d -> b h k d', qk, v)
303
  wv = rearrange(wv, 'b h c d -> b c (h d)')
304
  out = self.out(wv)
 
305
 
306
  class attentiond(nn.Module):
307
  def __init__(self, dims: int, head: int):
@@ -324,11 +319,11 @@ class attentiond(nn.Module):
324
  qk, v = self.kv(self.ln(x)).chunk(2, dim=-1)
325
  qka, va = self.kv(self.ln(x if xa is None else xa)).chunk(2, dim=-1)
326
  qk, qka, v, va = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.head), (qk, qka, v, va))
327
- qk = einsum('b h i d, b h j d -> b h i j', qk, qka)
328
  if there_is_a(mask):
329
- i, j = qk.shape[-2:]
330
- mask = torch.ones(i, j, device=device, dtype=torch.bool).triu(j - i + 1)
331
- qk = qk.masked_fill(mask, -torch.finfo(qk.dtype).max)
332
  x = qk.softmax(dim = -1)
333
  xa = qk.softmax(dim = -2)
334
  x = self.x(x)
@@ -337,15 +332,13 @@ class attentiond(nn.Module):
337
  xa = einsum('b h j i, b h j d -> b h i d', xa, v)
338
  x, xa = map(lambda t: rearrange(t, 'b h n d -> b n (h d)'), (x, xa))
339
  out = self.out(x)
340
- outxa = self.out(xa)
341
-
342
- return out, outxa, qk
343
 
344
  class tgate(nn.Module):
345
  def __init__(self, dims, num_types=4):
346
  super().__init__()
347
- self.gates = nn.ModuleList([nn.Sequential(nn.Linear(dims, 1), nn.Sigmoid()) for _ in range(num_types)])
348
- self.classifier = nn.Sequential(nn.Linear(dims, num_types), nn.Softmax(dim=-1))
349
  def forward(self, x):
350
  types = self.classifier(x)
351
  gates = torch.stack([gate(x) for gate in self.gates], dim=-1)
@@ -359,32 +352,34 @@ class residual(nn.Module):
359
  self.lna = nn.LayerNorm(dims, bias=False)
360
  self.atta = attentiona(dims, head)
361
  self.attb = attentionb(dims, head, max_iter=1)
362
- # self.attc = attentiona(dims, head, cross_talk=True)
363
 
364
- self.tgate = tgate(dims, num_types=4)
365
  self.mlp = nn.Sequential(nn.Linear(dims, dims*4), get_activation(act), nn.Linear(dims*4, dims))
366
 
367
- def forward(
368
- self,
369
- x: Tensor,
370
- xa: Optional[Tensor] = None,
371
- mask: Optional[Tensor] = None,
372
- ):
373
- x = x + self.atta(x, mask=mask)[0]
374
  if xa is not None:
375
- x = x + self.attb(x, xa, mask=None)
376
  x = x + self.tgate(x)
377
- x = x + self.mlp(self.lna(x))
378
  return x
379
 
380
  class processor(nn.Module):
381
  def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
 
382
  super(processor, self).__init__()
383
 
384
  self.ln = nn.LayerNorm(dims)
385
  self.token = nn.Embedding(vocab, dims)
386
  self.audio = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
387
  self.positions = nn.Parameter(torch.empty(ctx, dims), requires_grad=True)
 
388
 
389
  act_fn = get_activation(act)
390
  self.encoder = nn.Sequential(
@@ -393,12 +388,13 @@ class processor(nn.Module):
393
  nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
394
 
395
  self.blocka = nn.ModuleList([residual(dims, head, act_fn) for _ in range(layer)])
396
- self.blockm = nn.ModuleList([residual(dims, head, act_fn) for _ in range(layer // 2)])
397
 
 
398
  mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
399
  self.register_buffer("mask", mask, persistent=False)
400
 
401
- def forward(self, x, xa, sequential=False, modal=False, kv_cache=None) -> Tensor:
402
 
403
  if xa.dim() == 2:
404
  xa = xa.unsqueeze(0)
@@ -413,10 +409,22 @@ class processor(nn.Module):
413
  xa = block(xa, mask=None)
414
  x = block(x, mask=self.mask)
415
  x = block(x, xa, mask=None)
 
 
 
 
 
 
416
 
417
  for block in chain(self.blockm or []):
418
  xm = block(torch.cat([x, xa], dim=1), torch.cat([x, xa], dim=1), mask=None) if modal else None
419
  x = block(xm[:, :x.shape[1]], xm[:, x.shape[1]:], mask=None) if modal else x
 
 
 
 
 
 
420
 
421
  x = nn.functional.dropout(x, p=0.001, training=self.training)
422
  x = self.ln(x)
@@ -447,7 +455,7 @@ class Model(nn.Module):
447
  def adjust_window(self, loss, ctx):
448
  self.win_size = ((ctx // self.param.head))
449
  if loss < self.best_loss:
450
- win_size = (self.win_size * self.factor) #.clamp(0, ctx - 1)
451
  else:
452
  win_size = (self.win_size // self.factor).clamp(0, self.win_size - 1)
453
  self.win_size = win_size
@@ -455,15 +463,25 @@ class Model(nn.Module):
455
  self.update(win_size)
456
  return win_size
457
 
458
- def forward(self, labels=None, input_ids=None, pitch=None, pitch_tokens=None, spectrogram=None):
459
 
460
  x = input_ids
461
- xa = pitch if pitch is not None else spectrogram # xb = pitch_tokens
462
- logits = self.processor(x, xa)
 
 
 
 
 
 
 
 
 
 
463
  loss = None
464
  if labels is not None:
465
  loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
466
- self.adjust_window(loss=loss.item(), ctx=xa.shape[2])
467
  return {"logits": logits, "loss": loss}
468
 
469
  def _init_weights(self, module):
@@ -522,3 +540,116 @@ class Model(nn.Module):
522
  hooks.append(layer.v.register_forward_hook(save_to_cache))
523
  self.processor.apply(install_hooks)
524
  return cache, hooks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
  import warnings
3
+ import os
4
  import logging
5
  from itertools import chain
6
  import torch
 
8
  from typing import Optional
9
  import numpy as np
10
  from dataclasses import dataclass
 
11
  from einops import rearrange
12
+ from datasets import load_dataset, Audio
13
+ from echoutils import extract_features, setup_tokenizer, compute_metrics, DataCollator, preprocess_logits_for_metrics, sinusoids, get_activation
14
+ from datetime import datetime
15
+ from transformers.trainer_seq2seq import Seq2SeqTrainer
16
+ from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
17
 
18
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
19
  dtype = torch.float32
20
  warnings.filterwarnings("ignore")
21
  logging.basicConfig(level=logging.ERROR)
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
def there_is_a(val):
    """Report whether ``val`` holds a value, i.e. is anything other than ``None``."""
    if val is None:
        return False
    return True
25
 
 
86
  x1 = x1.view(orig_shape)
87
  return torch.cat([x1.type_as(x), x2], dim=-1)
88
 
89
def calculate_attention(q, k, v, mask=None, temp=1.0, pytorch=True):
    """Compute attention over ``q``, ``k``, ``v`` with an optional temperature.

    Args:
        q, k, v: Query, key and value tensors. The manual branch indexes them
            as 4-D tensors (it permutes axes 0, 2, 1, 3 and flattens from
            axis 2).
        mask: Optional 2-D mask. In the fused branch only its *presence* is
            used (to request causal masking); in the manual branch truthy
            entries (after ``.bool()``) are filled with ``-inf``.
        temp: When positive and different from 1.0, the queries are scaled by
            ``sqrt(1 / temp)`` before the scores are computed.
        pytorch: If True use ``torch.nn.functional.scaled_dot_product_attention``;
            otherwise use the explicit matmul/softmax implementation.

    Returns:
        Fused branch: whatever ``scaled_dot_product_attention`` returns (same
        layout as ``q``). Manual branch: the weighted values permuted to
        (0, 2, 1, 3) and flattened from axis 2, so the two branches return
        differently shaped tensors.
    """
    scaled_q = q
    if temp != 1.0 and temp > 0:
        scaled_q = q * (1.0 / temp) ** 0.5

    if pytorch:
        # NOTE(review): the mask tensor itself is never passed on; its presence
        # only toggles `is_causal`. For a (batch, head, ctx, dim) layout
        # q.shape[1] is the head count rather than the sequence length --
        # confirm which axis was intended.
        return torch.nn.functional.scaled_dot_product_attention(
            scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1)

    scale = q.shape[-1] ** -0.35
    # Use the temperature-scaled queries here as well; previously this branch
    # read `q` directly, so `temp` was silently ignored.
    qk = (scaled_q * scale) @ (k * scale).transpose(-1, -2)
    if mask is not None:
        # Crop rows to the query length and columns to the key length so the
        # mask also fits when the two lengths differ.
        mask = mask[:qk.shape[-2], :qk.shape[-1]]
        qk = qk.masked_fill(mask.bool(), -torch.inf)
    # Softmax in float32, then cast back to the input dtype.
    w = torch.nn.functional.softmax(qk.float(), dim=-1).to(q.dtype)
    return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
106
 
107
  class LocalOut(nn.Module):
 
119
  return attn_output.transpose(1, 2).contiguous().view(batch, ctx, self.dims)
120
 
121
  class attentionb(nn.Module):
122
+ def __init__(self, dims: int, head: int, max_iter: int = 3, threshold: float = 0.5, temp = 1.0):
 
123
  super(attentionb, self).__init__()
124
 
125
  self.head = head
126
  self.dims = dims
127
  self.head_dim = dims // head
 
128
 
129
  self.que = nn.Linear(dims, dims, bias=False)
130
  self.kv = nn.Linear(dims, dims * 2, bias=False)
131
  self.out = nn.Linear(dims, dims, bias=False)
132
 
133
  self.lna = nn.LayerNorm(dims)
134
+ self.lnb = nn.LayerNorm(dims // head)
135
  self.rope = rotary(dims, head)
136
 
137
  self.max_iter = max_iter
138
  self.threshold = nn.Parameter(torch.tensor(threshold), requires_grad=True)
139
  self.temp = nn.Parameter(torch.tensor(temp), requires_grad=True)
 
140
  self.local = LocalOut(dims, head)
141
 
142
  def update_win(self, win_size=None):
 
151
  def _focus(self, x, xa = None, mask = None, win_size=None):
152
 
153
  q = self.que(self.lna(x))
154
+ k, v = self.kv(self.lna(x if xa is None else xa)).chunk(2, dim=-1)
155
  q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
156
 
157
  self.scale = q.shape[-1] ** -0.35
 
159
  k = self.rope(k)
160
 
161
  iteration = 0
162
+ temp = self.temp.item()
163
  prev_out = torch.zeros_like(q)
164
  attn_out = torch.zeros_like(q)
165
  threshold = self.threshold
 
166
  curq = q #if curq is None else curq
167
 
168
  while iteration < self.max_iter:
169
+ eff_span = curq.shape[2]
 
 
170
  if eff_span == 0:
171
  break
172
 
 
191
  iter_out = torch.zeros_like(curq)
192
  iter_out[:, :, :eff_span, :] = attn_iter
193
  diff = torch.abs(iter_out - prev_out).mean()
194
+
195
+ if diff < threshold and iteration > 0:
196
  attn_out = iter_out
197
  break
198
 
199
  prev_out = iter_out.clone()
200
  curq = curq + iter_out
201
  attn_out = iter_out
 
 
 
 
 
 
202
  iteration += 1
203
+ temp -= 0.005
204
 
205
+ return rearrange(attn_out, 'b h c d -> b c (h d)')
 
206
 
207
  def _slide_win_local(self, x, mask = None) -> Tensor:
208
 
 
239
  def forward(self, x, xa = None, mask = None):
240
  x = self._slide_win_local(x, mask=None)
241
  xa = self._slide_win_local(xa, mask=None)
242
+ out = self._focus(x, xa, mask=None)
243
+ return self.out(out)
244
 
245
def scaled_relu(x, sequence_length):
    """Apply ReLU to ``x`` and divide the rectified values by ``sequence_length``."""
    return torch.relu(x) / sequence_length
248
+
249
def taylor_softmax(x, order=2):
    """Normalise a truncated Taylor expansion of ``exp(x)`` over the last axis.

    The numerator is ``sum(x**i / i! for i in 0..order)``; it is divided by its
    sum along ``dim=-1`` so each slice sums to one.

    Args:
        x: Input tensor of scores.
        order: Highest power kept in the expansion. ``order=0`` yields a
            uniform distribution.

    Returns:
        Tensor with the same shape as ``x``.

    Note:
        The truncated polynomial is not guaranteed to be positive for odd
        ``order`` and sufficiently negative inputs, in which case the result
        is not a valid probability distribution.
    """
    import math  # local import: the module does not import math at file level

    # Start from a tensor (the i = 0 term) so that order=0 still reduces over
    # a tensor instead of failing on a Python float.
    taylor_approx = torch.ones_like(x)
    for i in range(1, order + 1):
        # Exact integer factorial instead of exp(lgamma(i + 1)) on a new tensor.
        taylor_approx = taylor_approx + x ** i / math.factorial(i)
    return taylor_approx / torch.sum(taylor_approx, dim=-1, keepdim=True)
255
+
256
def taylor_softmax_2nd_order(x):
    """Normalise ``1 + x + x**2 / 2`` over the last axis of ``x``."""
    numerator = 1 + x + (x ** 2) / 2
    denominator = torch.sum(numerator, dim=-1, keepdim=True)
    return numerator / denominator
259
+
260
def cos_sim(q: Tensor, k: Tensor, v: Tensor, mask=None) -> Tensor:
    """Attention whose scores are cosine similarities between queries and keys.

    ``q`` and ``k`` are L2-normalised along the last axis, their pairwise dot
    products are softmaxed over the key axis, and the result weights ``v``.

    Args:
        q: Query tensor.
        k: Key tensor; its last axis must match ``q``'s.
        v: Value tensor.
        mask: Optional additive mask broadcastable to the score matrix
            (e.g. ``-inf`` at positions to suppress). ``None`` applies no mask.

    Returns:
        ``softmax(cosine(q, k) + mask) @ v``.
    """
    q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
    k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
    qk_cosine = torch.matmul(q_norm, k_norm.transpose(-1, -2))
    if mask is not None:
        qk_cosine = qk_cosine + mask
    # Spelled out in full: no `F` alias is imported in this module.
    weights = torch.nn.functional.softmax(qk_cosine, dim=-1)
    return torch.matmul(weights, v)
268
+
269
  class attentiona(nn.Module):
270
+ def __init__(self, dims: int, head: int, dropout_rate: float = 0.1):
271
  super().__init__()
272
+
273
  self.head = head
274
  self.dims = dims
 
 
275
  self.que = nn.Linear(dims, dims, bias=False)
276
  self.kv = nn.Linear(dims, dims * 2, bias=False)
277
  self.out = nn.Linear(dims, dims, bias=False)
 
278
  self.ln = nn.LayerNorm(dims)
279
  self.rope = rotary(dims, head)
280
 
 
 
 
281
  def forward(self, x, xa = None, mask = None):
282
 
283
  q = self.que(self.ln(x))
 
290
  k = self.rope(k)
291
 
292
  qk = einsum('b h k d, b h q d -> b h k q', q, k) * scale
293
+ # qk = torch.nn.functional.softmax(qk, dim=-1)
294
+ qk = taylor_softmax(qk, order=2)
295
 
 
 
 
 
 
 
296
  wv = einsum('b h k q, b h q d -> b h k d', qk, v)
297
  wv = rearrange(wv, 'b h c d -> b c (h d)')
298
  out = self.out(wv)
299
+ return out
300
 
301
  class attentiond(nn.Module):
302
  def __init__(self, dims: int, head: int):
 
319
  qk, v = self.kv(self.ln(x)).chunk(2, dim=-1)
320
  qka, va = self.kv(self.ln(x if xa is None else xa)).chunk(2, dim=-1)
321
  qk, qka, v, va = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.head), (qk, qka, v, va))
322
+ qk = einsum('b h q d, b h k d -> b h q k', qk, qka)
323
  if there_is_a(mask):
324
+ mask = mask[:qk.shape[2], :qk.shape[2]]
325
+ qk = qk.masked_fill(mask.bool(), -torch.inf)
326
+
327
  x = qk.softmax(dim = -1)
328
  xa = qk.softmax(dim = -2)
329
  x = self.x(x)
 
332
  xa = einsum('b h j i, b h j d -> b h i d', xa, v)
333
  x, xa = map(lambda t: rearrange(t, 'b h n d -> b n (h d)'), (x, xa))
334
  out = self.out(x)
335
+ return out
 
 
336
 
337
  class tgate(nn.Module):
338
  def __init__(self, dims, num_types=4):
339
  super().__init__()
340
+ self.gates = nn.ModuleList([nn.Sequential(nn.Linear(dims, dims), nn.Sigmoid()) for _ in range(num_types)])
341
+ self.classifier = nn.Sequential(nn.Linear(dims, num_types), torch.nn.functional.Softmax(dim=-1))
342
  def forward(self, x):
343
  types = self.classifier(x)
344
  gates = torch.stack([gate(x) for gate in self.gates], dim=-1)
 
352
  self.lna = nn.LayerNorm(dims, bias=False)
353
  self.atta = attentiona(dims, head)
354
  self.attb = attentionb(dims, head, max_iter=1)
355
+ self.attc = attentiond(dims, head)
356
 
357
+ self.tgate = tgate(dims, num_types=1)
358
  self.mlp = nn.Sequential(nn.Linear(dims, dims*4), get_activation(act), nn.Linear(dims*4, dims))
359
 
360
+ def forward(self, x: Tensor, xa = None, mask = None):
361
+
362
+ out = self.atta(x, mask=mask)
363
+ if x.shape == out.shape:
364
+ x = x + out
365
+ else:
366
+ x = out
367
  if xa is not None:
368
+ x = x + self.atta(x, xa, mask=None)
369
  x = x + self.tgate(x)
370
+ x = x + self.mlp(self.lna(x))
371
  return x
372
 
373
  class processor(nn.Module):
374
  def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
375
+
376
  super(processor, self).__init__()
377
 
378
  self.ln = nn.LayerNorm(dims)
379
  self.token = nn.Embedding(vocab, dims)
380
  self.audio = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
381
  self.positions = nn.Parameter(torch.empty(ctx, dims), requires_grad=True)
382
+ self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
383
 
384
  act_fn = get_activation(act)
385
  self.encoder = nn.Sequential(
 
388
  nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
389
 
390
  self.blocka = nn.ModuleList([residual(dims, head, act_fn) for _ in range(layer)])
391
+ self.blockm = nn.ModuleList([residual(dims, head, act_fn) for _ in range(2)])
392
 
393
+ mask = torch.triu(torch.ones(ctx, ctx), diagonal=1)
394
  mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
395
  self.register_buffer("mask", mask, persistent=False)
396
 
397
+ def forward(self, x, xa, xb, sequential=False, modal=False, kv_cache=None, blend=False) -> Tensor:
398
 
399
  if xa.dim() == 2:
400
  xa = xa.unsqueeze(0)
 
409
  xa = block(xa, mask=None)
410
  x = block(x, mask=self.mask)
411
  x = block(x, xa, mask=None)
412
+ if blend:
413
+ if sequential:
414
+ y = x
415
+ else:
416
+ a = torch.sigmoid(self.blend)
417
+ x = a * x + (1 - a) * y
418
 
419
  for block in chain(self.blockm or []):
420
  xm = block(torch.cat([x, xa], dim=1), torch.cat([x, xa], dim=1), mask=None) if modal else None
421
  x = block(xm[:, :x.shape[1]], xm[:, x.shape[1]:], mask=None) if modal else x
422
+ if blend:
423
+ if sequential:
424
+ y = x
425
+ else:
426
+ a = torch.sigmoid(self.blend)
427
+ x = a * x + (1 - a) * y
428
 
429
  x = nn.functional.dropout(x, p=0.001, training=self.training)
430
  x = self.ln(x)
 
455
  def adjust_window(self, loss, ctx):
456
  self.win_size = ((ctx // self.param.head))
457
  if loss < self.best_loss:
458
+ win_size = (self.win_size * self.factor)
459
  else:
460
  win_size = (self.win_size // self.factor).clamp(0, self.win_size - 1)
461
  self.win_size = win_size
 
463
  self.update(win_size)
464
  return win_size
465
 
466
+ def forward(self, labels=None, input_ids=None, pitch=None, pitch_tokens=None, spectrogram=None, waveform=None):
467
 
468
  x = input_ids
469
+ xa = pitch
470
+ xb = spectrogram
471
+
472
+ enc = {}
473
+ if spectrogram is not None:
474
+ enc["spectrogram"] = spectrogram
475
+ if waveform is not None:
476
+ enc["waveform"] = waveform
477
+ if pitch is not None:
478
+ enc["pitch"] = pitch
479
+
480
+ logits = self.processor(x, xa, xb)
481
  loss = None
482
  if labels is not None:
483
  loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
484
+ self.adjust_window(loss=loss.item(), ctx=xa.shape[1])
485
  return {"logits": logits, "loss": loss}
486
 
487
  def _init_weights(self, module):
 
540
  hooks.append(layer.v.register_forward_hook(save_to_cache))
541
  self.processor.apply(install_hooks)
542
  return cache, hooks
543
+
544
+ ### "pipeline"
545
+
546
def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=True, load_saved=False, save_dataset=True, cache_dir='E:/hf', extract_args=None, max_ctx=2048):
    """Build the Common Voice (en) train/test feature datasets.

    Args:
        tokenizer: Tokenizer forwarded to ``extract_features``.
        token: Hub access token forwarded to ``load_dataset``.
        sanity_check: Accepted for interface compatibility; not used here.
        sample_rate: Sampling rate the ``audio`` column is cast to.
        streaming: Accepted for interface compatibility. NOTE(review): the
            splits are always loaded with ``streaming=True``, as before;
            confirm whether this flag should be honoured.
        load_saved: If True and both cache files exist under ``cache_dir``,
            return the cached datasets instead of rebuilding them.
        save_dataset: If True, write the processed datasets under ``cache_dir``
            when the dataset objects support ``save_to_disk``.
        cache_dir: Directory holding ``train.arrow`` / ``test.arrow``. ``None``
            disables both loading from and saving to the cache.
        extract_args: Keyword arguments for ``extract_features``; ``None`` is
            treated as no extra arguments.
        max_ctx: Upper bound used by the length filter (transcription length,
            and ``max_ctx * 160`` audio samples).

    Returns:
        ``(train_dataset, test_dataset)``.
    """
    extract_args = {} if extract_args is None else extract_args

    # Resolve the cache paths up front so both the load and the save path can
    # use them regardless of which flags are set.
    cache_file_train = cache_file_test = None
    if cache_dir is not None:
        cache_file_train = os.path.join(cache_dir, "train.arrow")
        cache_file_test = os.path.join(cache_dir, "test.arrow")

    if load_saved and cache_file_train is not None:
        if os.path.exists(cache_file_train) and os.path.exists(cache_file_test):
            from datasets import Dataset
            train_dataset = Dataset.load_from_disk(cache_file_train)
            test_dataset = Dataset.load_from_disk(cache_file_test)
            return train_dataset, test_dataset

    def filter_func(x):
        # Keep non-empty samples whose text and audio fit within max_ctx.
        # NOTE(review): this runs before the audio column is cast to
        # `sample_rate`, so the sample count may be at the source rate.
        return (0 < len(x["transcription"]) < max_ctx and
                len(x["audio"]["array"]) > 0 and
                len(x["audio"]["array"]) < max_ctx * 160)

    def load_split(split):
        return load_dataset(
            "mozilla-foundation/common_voice_17_0", "en", token=token, split=split,
            trust_remote_code=True, streaming=True).rename_column("sentence", "transcription")

    def featurize(x):
        return extract_features(x, tokenizer, **extract_args)

    raw_train = load_split("train")
    raw_test = load_split("test").take(1000)

    raw_train = raw_train.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
    raw_test = raw_test.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
    train_dataset = raw_train.map(featurize).remove_columns(["audio", "transcription"])
    test_dataset = raw_test.map(featurize).remove_columns(["audio", "transcription"])

    if save_dataset and cache_file_train is not None:
        os.makedirs(cache_dir, exist_ok=True)
        for dataset, path in ((train_dataset, cache_file_train), (test_dataset, cache_file_test)):
            # Streamed (iterable) datasets may not expose save_to_disk; skip
            # them with a notice instead of failing after preprocessing.
            if hasattr(dataset, "save_to_disk"):
                dataset.save_to_disk(path)
            else:
                print(f"Skipping save to {path}: {type(dataset).__name__} has no save_to_disk")
    return train_dataset, test_dataset
576
+
577
def main():
    """Train the model on Common Voice with a ``Seq2SeqTrainer``.

    Creates a timestamped log directory, loads the tokenizer, prepares the
    datasets, builds the model, optimizer and cosine schedule, then runs
    training. The Hub token is read from the ``HF_TOKEN`` environment variable
    and defaults to an empty string when it is unset.
    """
    token = os.environ.get("HF_TOKEN", "")
    log_dir = os.path.join('D:/newmodel/output/logs/', datetime.now().strftime('%m-%d_%H_%M_%S'))
    os.makedirs(log_dir, exist_ok=True)
    tokenizer = setup_tokenizer("D:/newmodel/mod5/tokenizer.json")

    extract_args = {
        "waveform": True,
        "spec": True,
        "pitch_tokens": True,
        "pitch": True,
        "harmonics": False,
        "aperiodics": False,
        "phase_mod": False,
        "crepe": False,
        "sample_rate": 16000,
        "hop_length": 256,
        "mode": "mean",
        "debug": False,
    }

    param = Dimensions(vocab=40000, mels=128, ctx=2048, dims=512, head=4, layer=4, act="swish")

    train_dataset, test_dataset = prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False,
        load_saved=False, save_dataset=False, cache_dir=None, extract_args=extract_args, max_ctx=param.ctx)

    # Use the module-level `device` (CUDA when available, otherwise CPU)
    # instead of a hard-coded 'cuda', which fails on CPU-only machines.
    model = Model(param).to(device)
    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    from functools import partial
    metrics_fn = partial(compute_metrics, print_pred=True, num_samples=1, tokenizer=tokenizer, model=model)

    training_args = Seq2SeqTrainingArguments(
        output_dir=log_dir,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        max_steps=100000,
        eval_steps=1000,
        save_steps=1000,
        warmup_steps=1000,
        logging_steps=100,
        logging_dir=log_dir,
        logging_strategy="steps",
        eval_strategy="steps",
        # NOTE(review): save_steps / save_total_limit are set although the
        # save strategy is "no" -- confirm whether checkpoints are wanted.
        save_strategy="no",
        report_to=["tensorboard"],
        push_to_hub=False,
        save_total_limit=1,
        label_names=["labels"],
        save_safetensors=False,
        eval_on_start=False,
        batch_eval_metrics=False,
        disable_tqdm=False,
        include_tokens_per_second=True,
        include_num_input_tokens_seen=True,
        learning_rate=0.00025,
        weight_decay=0.025,
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, eps=1e-8, weight_decay=training_args.weight_decay, betas=(0.9, 0.999), amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_args.max_steps, eta_min=1e-9, last_epoch=-1)

    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=DataCollator(tokenizer=tokenizer),
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        compute_metrics=metrics_fn,
        optimizers=(optimizer, scheduler)
    )

    model.init_weights()
    trainer.train()
654
+ if __name__ == "__main__":
655
+ main()