Sin2pi committed (verified)
Commit 6296a84 · Parent(s): c0c3c62

Update model_simple.py

Files changed (1)
  1. model_simple.py +308 -193
model_simple.py CHANGED
@@ -7,15 +7,29 @@ from torch import nn, Tensor, einsum
 import numpy as np
 from dataclasses import dataclass
 from einops import rearrange
-from datetime import datetime
-from echoutils import *
-from transformers.trainer_seq2seq import Seq2SeqTrainer
-from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 dtype = torch.float32
 warnings.filterwarnings("ignore")
 logging.basicConfig(level=logging.ERROR)

 def sinusoids(ctx, dims, max_tscale=10000):
     assert dims % 2 == 0
     pos = torch.log(torch.tensor(float(max_tscale))) / (dims // 2 - 1)
@@ -40,52 +54,159 @@ def get_activation(act: str) -> nn.Module:
     }
     return act_map.get(act, nn.GELU())

-def there_is_a(val):
-    return val is not None
-
 @dataclass
 class Dimensions:
-    vocab: int
     mels: int
     ctx: int
     dims: int
     head: int
     layer: int
     act: str

 class rotary(nn.Module):
     def __init__(self, dims, head):
         super(rotary, self).__init__()
         self.dims = dims
         self.head = head
         self.head_dim = dims // head

-        self.theta = nn.Parameter((torch.tensor(16000, device=device, dtype=dtype)), requires_grad=True)
         self.register_buffer('freqs_base', self._compute_freqs_base(), persistent=False)

     def _compute_freqs_base(self):
         mel_scale = torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 4000/200)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1
         return 200 * mel_scale / 1000

-    def forward(self, x) -> Tensor:
-        freqs = (self.theta / 220.0) * self.freqs_base
-        pos = torch.arange(x.shape[2], device=device, dtype=dtype)
-        freqs = pos[:, None] * freqs
-        freqs = torch.polar(torch.ones_like(freqs), freqs)
-
-        x1 = x[..., :freqs.shape[-1]*2]
-        x2 = x[..., freqs.shape[-1]*2:]
-        orig_shape = x1.shape
-        x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
-        x1 = torch.view_as_complex(x1) * freqs
-        x1 = torch.view_as_real(x1).flatten(-2)
-        x1 = x1.view(orig_shape)
-        return torch.cat([x1.type_as(x), x2], dim=-1)

 class attentiona(nn.Module):
     def __init__(self, dims: int, head: int):
         super().__init__()
-
         self.head = head
         self.dims = dims
         self.head_dim = dims // head
@@ -95,25 +216,23 @@ class attentiona(nn.Module):
         self.zmax = 1e-5
         self.zero = nn.Parameter(torch.tensor(1e-4, device=device, dtype=dtype), requires_grad=False)

-        self.q = nn.Linear(dims, dims, bias=False)
         self.kv = nn.Linear(dims, dims * 2, bias=False)
-        self.out = nn.Linear(dims, dims, bias=False)

         self.lna = nn.LayerNorm(dims)
         self.rope = rotary(dims, head)

-    def forward(self, x, xa = None, mask = None):
         zero = self.zero

-        q = self.q(self.lna(x))
         k, v = self.kv(self.lna(x if xa is None else xa)).chunk(2, dim=-1)
         q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
         scale = q.shape[-1] ** -0.5

-        q = self.rope(q)
-        k = self.rope(k)
-
-        qk = einsum('b h k d, b h q d -> b h k q', q, k) * scale

         scale = torch.ones_like(k[:, :, :, 0])
         zero = torch.clamp(F.softplus(zero), 1e-6, 1e-5)
@@ -134,7 +253,7 @@ class attentiona(nn.Module):
         return out

 class tgate(nn.Module):
-    def __init__(self, dims, num_types=4):
         super().__init__()
         self.gates = nn.ModuleList([nn.Sequential(nn.Linear(dims, dims), nn.Sigmoid()) for _ in range(num_types)])
         self.classifier = nn.Sequential(nn.Linear(dims, num_types), nn.Softmax(dim=-1))
@@ -145,78 +264,189 @@ class tgate(nn.Module):
     return cgate

 class residual(nn.Module):
-    def __init__(self, dims: int, head: int, act: str = "silu"):
         super().__init__()

-        self.lna = nn.LayerNorm(dims, bias=False)
         self.atta = attentiona(dims, head)

         self.tgate = tgate(dims, num_types=1)
         self.mlp = nn.Sequential(nn.Linear(dims, dims*4), get_activation(act), nn.Linear(dims*4, dims))

-    def forward(self, x: Tensor, xa = None, mask = None):
-
-        out = self.atta(x, mask=mask)
-        if x.shape == out.shape:
-            x = x + out
-        else:
-            x = out
-        if xa is not None:
-            x = x + self.atta(x, xa, mask=None)
-
         x = x + self.tgate(x)
         x = x + self.mlp(self.lna(x))
         return x

 class processor(nn.Module):
-    def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu", modal=True):
         super(processor, self).__init__()

         self.ln = nn.LayerNorm(dims)
-        self.token = nn.Embedding(vocab, dims)
         self.audio = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)

         self.positions = nn.Parameter(torch.empty(ctx, dims), requires_grad=True)
         self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
-
-        act_fn = get_activation(act)
         self.encoder = nn.Sequential(
-            nn.Conv1d(1, dims, kernel_size=3, stride=1, padding=1), act_fn,
             nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
             nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)

-        self.block = nn.ModuleList([residual(dims, head, act_fn) for _ in range(layer)])
         mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)

-    def forward(self, x, xa, enc=None, sequential=False, modal=True, blend=False, kv_cache=None) -> Tensor:
-        mask = self.mask[:x.shape[1], :x.shape[1]]

         offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
         x = (self.token(x.long()) + self.positions[offset : offset + x.shape[-1]])

         xa = self.encoder(xa).permute(0, 2, 1)
         xa = xa + self.audio(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)

-        for block in chain(self.block or []):
-            xa = block(xa, mask=None)
-            x = block(x, mask=mask)
-            x = block(x, xa, mask=None)
-            if blend:
-                if sequential:
-                    y = x
-                else:
-                    a = torch.sigmoid(self.blend)
-                    x = a * x + (1 - a) * y
-
-            xm = block(torch.cat([x, xa], dim=1), mask=mask) if modal else None
-            x = block(xm[:, :x.shape[1]], xm[:, x.shape[1]:], mask=None) if modal else x
-            if blend:
-                if sequential:
-                    y = x
-                else:
-                    a = torch.sigmoid(self.blend)
-                    x = a * x + (1 - a) * y

         x = nn.functional.dropout(x, p=0.001, training=self.training)
         x = self.ln(x)
@@ -228,18 +458,19 @@ class Model(nn.Module):
         super().__init__()
         self.param = param
         self.processor = processor(
-            vocab=param.vocab,
             mels=param.mels,
             ctx=param.ctx,
             dims=param.dims,
             head=param.head,
             layer=param.layer,
             act=param.act)

     def forward(self, labels=None, input_ids=None, pitch=None, pitch_tokens=None, spectrogram=None, waveform=None):

         x = input_ids
-        xa = pitch

         enc = {}
         if spectrogram is not None:
@@ -248,6 +479,8 @@ class Model(nn.Module):
             enc["waveform"] = waveform
         if pitch is not None:
             enc["pitch"] = pitch

         logits = self.processor(x, xa, enc)
         loss = None
@@ -259,7 +492,7 @@ class Model(nn.Module):
     def _init_weights(self, module):
         self.init_counts = {
             "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
-            "Conv2d": 0, "processor": 0, "attention": 0, "Residual": 0}
         for name, module in self.named_modules():
             if isinstance(module, nn.RMSNorm):
                 nn.init.ones_(module.weight)
@@ -295,121 +528,3 @@ class Model(nn.Module):
         for module_type, count in self.init_counts.items():
             if count > 0:
                 print(f"{module_type}: {count}")
-
-def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=True, load_saved=False, save_dataset=True, cache_dir='E:/hf', extract_args=None, max_ctx=2048):
-
-    if load_saved:
-        if cache_dir is None:
-            cache_dir = cache_dir
-        else:
-            cache_dir = cache_dir
-
-        os.makedirs(cache_dir, exist_ok=True)
-        cache_file_train = os.path.join(cache_dir, "train.arrow")
-        cache_file_test = os.path.join(cache_dir, "test.arrow")
-
-        if os.path.exists(cache_file_train) and os.path.exists(cache_file_test):
-            from datasets import Dataset
-            train_dataset = Dataset.load_from_disk(cache_file_train)
-            test_dataset = Dataset.load_from_disk(cache_file_test)
-            return train_dataset, test_dataset
-
-    def filter_func(x):
-        return (0 < len(x["transcription"]) < max_ctx and
-                len(x["audio"]["array"]) > 0 and
-                len(x["audio"]["array"]) < max_ctx * 160)
-
-    raw_train = load_dataset(
-        "google/fleurs", "en_us", token=token, split="train", streaming=streaming).take(1000)
-    raw_test = load_dataset(
-        "google/fleurs", "en_us", token=token, split="test", streaming=streaming).take(100)
-
-    raw_train = raw_train.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
-    raw_test = raw_test.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
-    train_dataset = raw_train.map(lambda x: extract_features(x, tokenizer, **extract_args)).remove_columns(["audio", "transcription"])
-    test_dataset = raw_test.map(lambda x: extract_features(x, tokenizer, **extract_args)).remove_columns(["audio", "transcription"])
-    train_dataset.save_to_disk(cache_file_train) if save_dataset is True else None
-    test_dataset.save_to_disk(cache_file_test) if save_dataset is True else None
-    return train_dataset, test_dataset
-
-def main():
-    token = ""
-    log_dir = os.path.join('D:/newmodel/output/logs/', datetime.now().strftime('%m-%d_%H_%M_%S'))
-    os.makedirs(log_dir, exist_ok=True)
-    tokenizer = setup_tokenizer("D:/newmodel/mod5/tokenizer.json")
-
-    extract_args = {
-        "waveform": False,
-        "spec": False,
-        "pitch_tokens": False,
-        "pitch": True,
-        "harmonics": False,
-        "aperiodics": False,
-        "phase_mod": False,
-        "crepe": False,
-        "sample_rate": 16000,
-        "hop_length": 256,
-        "mode": "mean",
-        "debug": False,
-    }
-
-    param = Dimensions(vocab=40000, mels=128, ctx=2048, dims=512, head=4, layer=4, act="swish")
-
-    train_dataset, test_dataset = prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False,
-        load_saved=False, save_dataset=False, cache_dir=None, extract_args=extract_args, max_ctx=param.ctx)
-
-    model = Model(param).to('cuda')
-    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
-    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
-
-    from functools import partial
-    metrics_fn = partial(compute_metrics, print_pred=True, num_samples=1, tokenizer=tokenizer, model=model)
-
-    training_args = Seq2SeqTrainingArguments(
-        output_dir=log_dir,
-        per_device_train_batch_size=1,
-        per_device_eval_batch_size=1,
-        max_steps=1000,
-        eval_steps=100,
-        save_steps=100,
-        warmup_steps=10,
-        logging_steps=10,
-        logging_dir=log_dir,
-        logging_strategy="steps",
-        eval_strategy="steps",
-        save_strategy="no",
-        report_to=["tensorboard"],
-        push_to_hub=False,
-        save_total_limit=1,
-        label_names=["labels"],
-        save_safetensors=False,
-        eval_on_start=False,
-        batch_eval_metrics=False,
-        disable_tqdm=False,
-        include_tokens_per_second=True,
-        include_num_input_tokens_seen=True,
-        learning_rate=0.00025,
-        weight_decay=0.025,
-    )
-
-    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, eps=1e-10, weight_decay=training_args.weight_decay, betas=(0.9, 0.999), amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False)
-
-    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_args.max_steps, eta_min=1e-9, last_epoch=-1)
-
-    trainer = Seq2SeqTrainer(
-        args=training_args,
-        model=model,
-        train_dataset=train_dataset,
-        eval_dataset=test_dataset,
-        data_collator=DataCollator(tokenizer=tokenizer),
-        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
-        compute_metrics=metrics_fn,
-        optimizers=(optimizer, scheduler)
-    )
-
-    model.init_weights()
-    trainer.train()
-
-if __name__ == "__main__":
-    main()
-
 import numpy as np
 from dataclasses import dataclass
 from einops import rearrange
+
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 dtype = torch.float32
 warnings.filterwarnings("ignore")
 logging.basicConfig(level=logging.ERROR)

+def scaled_relu(x, sequence_length):
+    relu_output = torch.relu(x)
+    return relu_output / sequence_length
+
+def taylor_softmax(x, order=2):
+    tapprox = 1.0
+    for i in range(1, order + 1):
+        factorial_i = torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
+        tapprox += x**i / factorial_i
+    return tapprox / torch.sum(tapprox, dim=-1, keepdim=True)
+
+def there_is_a(a):
+    return a is not None
+
+def AorB(a, b):
+    return a if there_is_a(a) else b
+
 def sinusoids(ctx, dims, max_tscale=10000):
     assert dims % 2 == 0
     pos = torch.log(torch.tensor(float(max_tscale))) / (dims // 2 - 1)
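# A minimal usage sketch (not part of the commit) for the helpers added above:
# taylor_softmax replaces exp(x) with the truncated Taylor polynomial 1 + x + x**2/2! + ...
# before normalizing, and AorB returns its first non-None argument.
import torch

scores = torch.randn(2, 4, 8)
probs = taylor_softmax(scores, order=2)                  # normalized along the last dim
assert torch.allclose(probs.sum(dim=-1), torch.ones(2, 4), atol=1e-5)

xa = AorB(None, torch.zeros(1, 128, 100))                # falls back to the second argument
assert xa.shape == (1, 128, 100)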
 
     }
     return act_map.get(act, nn.GELU())

 @dataclass
 class Dimensions:
+    tokens: int
     mels: int
     ctx: int
     dims: int
     head: int
+    head_dim: int
     layer: int
     act: str

+def vectorized_taylor_sine(x, order=5):
+    original_shape = x.shape
+    x = x.flatten(0, -2)
+    exponents = torch.arange(1, order + 1, 2, device=x.device, dtype=torch.float32)
+    x_powers = x.unsqueeze(-1) ** exponents
+    factorials = torch.exp(torch.lgamma(exponents + 1))
+    signs = (-1)**(torch.arange(0, len(exponents), device=x.device, dtype=torch.float32))
+    terms = signs * x_powers / factorials
+    result = terms.sum(dim=-1)
+    return result.view(original_shape)
+
+def vectorized_taylor_cosine(x, order=5):
+    original_shape = x.shape
+    x = x.flatten(0, -2)
+    exponents = torch.arange(0, order + 1, 2, device=x.device, dtype=torch.float32)
+    x_powers = x.unsqueeze(-1) ** exponents
+    factorials = torch.exp(torch.lgamma(exponents + 1))
+    signs = (-1)**(torch.arange(0, len(exponents), device=x.device, dtype=torch.float32))
+    terms = signs * x_powers / factorials
+    result = terms.sum(dim=-1)
+    return result.view(original_shape)
+
 class rotary(nn.Module):
     def __init__(self, dims, head):
         super(rotary, self).__init__()
         self.dims = dims
         self.head = head
         self.head_dim = dims // head
+        self.taylor_order = 10

+        self.theta = nn.Parameter((torch.tensor(360000, device=device, dtype=dtype)), requires_grad=False)
         self.register_buffer('freqs_base', self._compute_freqs_base(), persistent=False)

     def _compute_freqs_base(self):
         mel_scale = torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 4000/200)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1
         return 200 * mel_scale / 1000

+    def forward(self, x) -> torch.Tensor:
+        positions = (torch.arange(0, x.shape[2], device=x.device))
+        freqs = (self.theta / 220.0) * self.freqs_base
+        freqs = positions[:, None] * freqs
+        freqs_rescaled = (freqs + torch.pi) % (2 * torch.pi) - torch.pi
+
+        with torch.autocast(device_type="cuda", enabled=False):
+            cos = vectorized_taylor_cosine(freqs_rescaled, order=self.taylor_order)
+            sin = vectorized_taylor_sine(freqs_rescaled, order=self.taylor_order)
+        rotary_dim = cos.shape[-1]
+        x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
+        x_embed = (x_rot * cos) + (rotate_half(x_rot) * sin)
+        x_embed = torch.cat([x_embed, x_pass], dim=-1)
+        return x_embed.type_as(x)
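# A quick check (not part of the commit): with the phase wrapped into [-pi, pi) by the
# (freqs + pi) % (2*pi) - pi line above, the order-10 truncated Taylor series closely
# tracks torch.sin / torch.cos, so the rotation behaves like a standard RoPE.
import torch

angles = torch.linspace(-torch.pi, torch.pi, 64).unsqueeze(0)   # 2-D, as the helpers expect
assert torch.allclose(vectorized_taylor_sine(angles, order=10), torch.sin(angles), atol=5e-2)
assert torch.allclose(vectorized_taylor_cosine(angles, order=10), torch.cos(angles), atol=5e-2)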
119
+
120
+ def taylor_sine(x, order=5):
121
+ result = torch.zeros_like(x)
122
+ for i in range(order + 1):
123
+ if i % 2 == 1:
124
+ term = x**i / torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
125
+ if (i // 2) % 2 == 1:
126
+ result -= term
127
+ else:
128
+ result += term
129
+ return result
130
+
131
+ def taylor_cosine(x, order=5):
132
+ result = torch.zeros_like(x)
133
+ for i in range(order + 1):
134
+ if i % 2 == 0:
135
+ term = x**i / torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
136
+ if (i // 2) % 2 == 1:
137
+ result -= term
138
+ else:
139
+ result += term
140
+ return result
141
+
142
+ class rotarya(nn.Module):
143
+ def __init__(self, dims, head):
144
+ super(rotary, self).__init__()
145
+ self.dims = dims
146
+ self.head = head
147
+ self.head_dim = dims // head
148
+ self.taylor_order = 5
149
+
150
+ self.theta = nn.Parameter((torch.tensor(1600, device=device, dtype=dtype)), requires_grad=False)
151
+ self.register_buffer('freqs_base', self._compute_freqs_base(), persistent=False)
152
+
153
+ def _compute_freqs_base(self):
154
+ mel_scale = torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 4000/200)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1
155
+ return 200 * mel_scale / 1000
156
+
157
+ def forward(self, x) -> torch.Tensor:
158
+
159
+ positions = (torch.arange(0, x.shape[2], device=x.device))
160
+ freqs = (self.theta / 220.0) * self.freqs_base
161
+ freqs = positions[:, None] * freqs
162
+ freqs = (freqs + torch.pi) % (2 * torch.pi) - torch.pi
163
+ with torch.autocast(device_type="cuda", enabled=False):
164
+ cos = taylor_cosine(freqs, order=self.taylor_order)
165
+ sin = taylor_sine(freqs, order=self.taylor_order)
166
+ rotary_dim = cos.shape[-1]
167
+ x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
168
+ x_embed = (x_rot * cos) + (rotate_half(x_rot) * sin)
169
+ x_embed = torch.cat([x_embed, x_pass], dim=-1)
170
+ return x_embed.type_as(x)
171
+
172
+ def rotate_half(x):
173
+ x1 = x[..., : x.shape[-1] // 2]
174
+ x2 = x[..., x.shape[-1] // 2 :]
175
+ return torch.cat((-x2, x1), dim=-1)
176
+
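# A small sanity check (not part of the commit): rotate_half is a quarter turn in each
# (x1, x2) plane, so applying it twice negates its input - the property that makes
# x*cos + rotate_half(x)*sin a rotation.
import torch

v = torch.randn(2, 4, 16, 64)                    # (batch, head, ctx, head_dim)
assert torch.allclose(rotate_half(rotate_half(v)), -v)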
+# class rotary(nn.Module):
+#     def __init__(self, dims, head):
+#         super(rotary, self).__init__()
+#         self.dims = dims
+#         self.head = head
+#         self.head_dim = dims // head
+
+#         self.theta = nn.Parameter((torch.tensor(1600, device=device, dtype=dtype)), requires_grad=False)
+#         # self.register_buffer('freqs_base', self._compute_freqs_base(), persistent=False)
+
+#     def _compute_freqs_base(self):
+#         mel_scale = torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 4000/200)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1
+#         return 200 * mel_scale / 1000
+
+#     def forward(self, x) -> Tensor:
+#         positions = (torch.arange(0, x.shape[2], device=x.device))
+#         freqs = (self.theta / 220.0) * self._compute_freqs_base()
+#         freqs = positions[:, None] * freqs
+
+#         with torch.autocast(device_type="cuda", enabled=False):
+#             freqs = torch.polar(torch.ones_like(freqs), freqs)
+#             x1 = x[..., :freqs.shape[-1]*2]
+#             x2 = x[..., freqs.shape[-1]*2:]
+#             orig_shape = x1.shape
+#             x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
+#             x1 = torch.view_as_complex(x1) * freqs
+#             x1 = torch.view_as_real(x1).flatten(-2)
+#             x1 = x1.view(orig_shape)
+#             return torch.cat([x1.type_as(x), x2], dim=-1)

 class attentiona(nn.Module):
     def __init__(self, dims: int, head: int):
         super().__init__()
         self.head = head
         self.dims = dims
         self.head_dim = dims // head

         self.zmax = 1e-5
         self.zero = nn.Parameter(torch.tensor(1e-4, device=device, dtype=dtype), requires_grad=False)

+        self.q = nn.Linear(dims, dims)
         self.kv = nn.Linear(dims, dims * 2, bias=False)
+        self.out = nn.Linear(dims, dims)

         self.lna = nn.LayerNorm(dims)
+        self.lnb = nn.LayerNorm(dims // head)
         self.rope = rotary(dims, head)

+    def forward(self, x, xa = None, mask = None, positions = None):
         zero = self.zero

+        q = self.q(x)
         k, v = self.kv(self.lna(x if xa is None else xa)).chunk(2, dim=-1)
         q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
         scale = q.shape[-1] ** -0.5

+        qk = einsum('b h k d, b h q d -> b h k q', self.lnb(q), self.lnb(k)) * scale

         scale = torch.ones_like(k[:, :, :, 0])
         zero = torch.clamp(F.softplus(zero), 1e-6, 1e-5)
     return out

 class tgate(nn.Module):
+    def __init__(self, dims, num_types=1):
         super().__init__()
         self.gates = nn.ModuleList([nn.Sequential(nn.Linear(dims, dims), nn.Sigmoid()) for _ in range(num_types)])
         self.classifier = nn.Sequential(nn.Linear(dims, num_types), nn.Softmax(dim=-1))
     return cgate

 class residual(nn.Module):
+    def __init__(self, dims: int, head: int, layer = 2, act = "silu"):
         super().__init__()

+        self.lna = nn.LayerNorm(dims, bias=False)
         self.atta = attentiona(dims, head)
+        self.dsl = skip_layer(dims, head, layer=2)

         self.tgate = tgate(dims, num_types=1)
         self.mlp = nn.Sequential(nn.Linear(dims, dims*4), get_activation(act), nn.Linear(dims*4, dims))

+    def forward(self, x: Tensor, xa = None, mask = None, positions=None):
+        # log = {}
+        x = x + self.atta(self.lna(x), xa=xa, mask=mask)
+        x, _ = self.dsl(self.lna(x), xa=xa, mask=mask)  # _ outputs logs for jumps
         x = x + self.tgate(x)
         x = x + self.mlp(self.lna(x))
+        # print(results['jumps'])
+        # log['jumps'] = l
         return x
+
+class skip_layer(nn.Module):
+    def __init__(self, dims, head, layer, threshold=0.1):
+        super().__init__()
+        self.layers = nn.ModuleList()
+        self.layer = layer

+        self.threshold = threshold
+        self.dims = dims
+        self.head = head
+        self.head_dim = dims // head
+
+        self.attention_module = attentiona(dims, head)
+        self.node_predictors = nn.ModuleList([
+            nn.Sequential(
+                nn.LayerNorm(dims),
+                nn.Linear(dims, 1),
+                nn.Sigmoid()
+            ) for _ in range(layer)
+        ])
+
+        for i in range(layer):
+            self.layers.append(nn.ModuleDict({
+                'ln': nn.LayerNorm(dims),
+                'gate': nn.Sequential(nn.Linear(dims, 1), nn.Sigmoid()),
+                'adapter': nn.Linear(dims, dims) if i % 2 == 0 else None
+            }))
+
+        self.policy_net = nn.Sequential(
+            nn.Linear(dims, 128),
+            nn.ReLU(),
+            nn.Linear(128, 3))
+
+        self.jump_weights = nn.Parameter(torch.tensor([0.1, 0.05, 0.01]))
+
+        n_mlp = dims * 4
+        self.mlp_gate = nn.Sequential(nn.Linear(dims, 1), nn.Sigmoid())
+        self.mlp = nn.Sequential(nn.Linear(dims, n_mlp), nn.GELU(), nn.Linear(n_mlp, dims))
+        self.mlp_ln = nn.LayerNorm(dims)
+        self.working_memory = nn.Parameter(torch.zeros(1, 1, dims))
+        self.memory_gate = nn.Sequential(nn.Linear(dims, 1), nn.Sigmoid())
+
+    def _calculate_shared_attention(self, x, mask=None):
+        return self.attention_module(x, xa=x, mask=None)
+
+    def predict_node_importance(self, x, layer_idx):
+        importance = self.node_predictors[layer_idx](x)
+        return (importance > self.threshold).float()
+
+    def forward(self, x, xa=None, mask=None):
+        batch, ctx = x.shape[:2]
+
+        working_memory = self.working_memory.expand(batch, -1, -1)
+        original_x = x
+        pooled_representation = x.mean(dim=1)
+        policy_logits = self.policy_net(pooled_representation)
+        policy = F.softmax(policy_logits, dim=-1)
+
+        jump_history = []
+        i = 0
+        while i < self.layer:
+            layer = self.layers[i]
+            node_importance = self.predict_node_importance(x, i)
+            if node_importance.mean() < 0.2 and i > 0:
+                i += 1
+                jump_history.append(i)
+                continue
+
+            norm_x = layer['ln'](x)
+            importance_mask_base = node_importance.unsqueeze(1).contiguous()
+            combined_custom_mask = None
+            if mask is None:
+                combined_custom_mask = importance_mask_base
+            else:
+                combined_custom_mask = mask.contiguous() * importance_mask_base
+
+            if node_importance.mean() > 0.3:
+                attn_output = self._calculate_shared_attention(norm_x, mask=combined_custom_mask.contiguous())
+                if layer['adapter'] is not None:
+                    attn_output = layer['adapter'](attn_output)
+
+                gate_value = layer['gate'](norm_x)
+                x = x + gate_value * attn_output
+                memory_gate = self.memory_gate(x)
+                working_memory = memory_gate * working_memory + (1 - memory_gate) * x.mean(dim=1, keepdim=True)
+
+            jump_prob = policy[:, 1] if i < self.layer - 1 else torch.zeros_like(policy[:, 1])
+            should_jump = (torch.rand_like(jump_prob) < jump_prob).any()
+
+            if should_jump:
+                jump_length = torch.multinomial(policy, 1)[:, 0].max().item() + 1
+                i_next = min(i + jump_length, self.layer - 1)
+                skip_weight = self.jump_weights[min(jump_length-1, 2)]
+                x = x + skip_weight * original_x + (1-skip_weight) * working_memory
+                i = i_next
+                jump_history.append(i)
+            else:
+                i += 1
+
+        mlp_importance = self.mlp_gate(x)
+        mlp_output = self.mlp(self.mlp_ln(x))
+        x = x + mlp_importance * mlp_output
+        return x, {'jumps': jump_history}
+
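# A minimal sketch (not part of the commit). skip_layer.forward returns an (output, log)
# pair - log['jumps'] lists the layer indices reached after each skip decision - which is
# why residual.forward above unpacks it as `x, _ = self.dsl(...)`. The node-importance
# gate that drives the skipping can be probed on its own (the dims/head values here are
# only illustrative):
import torch

dsl = skip_layer(dims=512, head=4, layer=2).to(device)
x = torch.randn(2, 16, 512, device=device)
importance = dsl.predict_node_importance(x, layer_idx=0)    # binary gate, shape (2, 16, 1)
print(importance.mean())                                    # fraction of positions kept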
 class processor(nn.Module):
+    def __init__(self, tokens, mels, ctx, dims, head, head_dim, layer, act):
         super(processor, self).__init__()

+        act_fn = get_activation(act)
         self.ln = nn.LayerNorm(dims)
+        self.token = nn.Embedding(tokens, dims)
         self.audio = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)

         self.positions = nn.Parameter(torch.empty(ctx, dims), requires_grad=True)
         self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
+
         self.encoder = nn.Sequential(
+            nn.Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
             nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
             nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)

+        modal = False
+        self.block = nn.ModuleList([residual(dims, head, layer, act_fn) for _ in range(layer)]) if modal else None
+
+        self.res = residual(dims, head, layer, act_fn)
         mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)

+    def init_memory(self, batch):
+        return torch.zeros(batch, 1, self.dims).to(next(self.parameters()).device)
+
+    def update_memory(self, x, working_memory):
+        return (x + working_memory) / 2

+    def forward(self, x, xa, enc=None, sequential=False, modal=False, blend=False, kv_cache=None) -> Tensor:
+
+        mask = self.mask[:x.shape[1], :x.shape[1]]
         offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
         x = (self.token(x.long()) + self.positions[offset : offset + x.shape[-1]])

         xa = self.encoder(xa).permute(0, 2, 1)
         xa = xa + self.audio(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)

+        xa = self.res(xa, None, None)
+        x = self.res(x, None, mask)
+        x = self.res(x, xa, None)
+
+        if blend:
+            if sequential:
+                y = x
+            else:
+                a = torch.sigmoid(self.blend)
+                x = a * x + (1 - a) * y
+
+        if modal:
+            for block in chain(self.block or []):
+                xm = block(torch.cat([x, xa], dim=1), mask=mask) if modal else None
+                x = block(xm[:, :x.shape[1]], xm[:, x.shape[1]:], mask=None) if modal else x
+                if blend:
+                    if sequential:
+                        y = x
+                    else:
+                        a = torch.sigmoid(self.blend)
+                        x = a * x + (1 - a) * y

         x = nn.functional.dropout(x, p=0.001, training=self.training)
         x = self.ln(x)
 
         super().__init__()
         self.param = param
         self.processor = processor(
+            tokens=param.tokens,
             mels=param.mels,
             ctx=param.ctx,
             dims=param.dims,
             head=param.head,
+            head_dim=param.head_dim,
             layer=param.layer,
             act=param.act)

     def forward(self, labels=None, input_ids=None, pitch=None, pitch_tokens=None, spectrogram=None, waveform=None):

         x = input_ids
+        xa = AorB(pitch, spectrogram)

         enc = {}
         if spectrogram is not None:

             enc["waveform"] = waveform
         if pitch is not None:
             enc["pitch"] = pitch
+        if pitch_tokens is not None:
+            enc["ptokens"] = pitch_tokens

         logits = self.processor(x, xa, enc)
         loss = None

     def _init_weights(self, module):
         self.init_counts = {
             "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
+            "Conv2d": 0, "processor": 0, "attentiona": 0, "Residual": 0}
         for name, module in self.named_modules():
             if isinstance(module, nn.RMSNorm):
                 nn.init.ones_(module.weight)

         for module_type, count in self.init_counts.items():
             if count > 0:
                 print(f"{module_type}: {count}")
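# A construction sketch (not part of the commit), reflecting the renamed and added
# Dimensions fields (vocab -> tokens, new head_dim) and the encoder's first Conv1d now
# taking `mels` input channels, so acoustic features are expected as (batch, mels, frames).
# The hyperparameter values mirror the training script removed by this commit; head_dim
# is an assumed value (dims // head).
param = Dimensions(tokens=40000, mels=128, ctx=2048, dims=512, head=4,
                   head_dim=512 // 4, layer=4, act="swish")
model = Model(param).to(device)
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# Expected inputs (shapes assumed, not stated in the diff):
#   input_ids: LongTensor (batch, text_ctx); pitch or spectrogram: FloatTensor (batch, mels, frames)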