Sin2pi committed on
Commit b62df1e · verified · 1 Parent(s): fd4647c

Update model_simple.py

Files changed (1)
  1. model_simple.py +468 -168
model_simple.py CHANGED
@@ -3,18 +3,134 @@ import warnings
3
  import logging
4
  from itertools import chain
5
  import torch
6
- from torch import nn, Tensor
7
- from typing import Optional, Dict
 
 
8
  import numpy as np
9
- from datetime import datetime
10
  from dataclasses import dataclass
 
11
  from torch.nn.functional import scaled_dot_product_attention
12
- from echoutils import *
 
13
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14
  dtype = torch.float32
15
  warnings.filterwarnings("ignore")
16
  logging.basicConfig(level=logging.ERROR)
17
 
18
  @dataclass
19
  class Dimensions:
20
  vocab: int
@@ -25,22 +141,47 @@ class Dimensions:
25
  layer: int
26
  act: str
27
 
28
  class rotary(nn.Module):
29
  def __init__(self, dims, head):
30
  super(rotary, self).__init__()
31
  self.dims = dims
32
  self.head = head
33
  self.head_dim = dims // head
34
- self.theta = nn.Parameter((torch.tensor(36000, device=device, dtype=dtype)), requires_grad=True)
 
35
  self.register_buffer('freqs_base', self._compute_freqs_base(), persistent=False)
36
 
37
  def _compute_freqs_base(self):
38
  mel_scale = torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 4000/200)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1
39
  return 200 * mel_scale / 1000
40
 
41
- def forward(self, x, ctx) -> Tensor:
42
- freqs = (self.theta / 220.0) * self.freqs_base
43
- pos = torch.arange(ctx, device=device, dtype=dtype)
 
44
  freqs = pos[:, None] * freqs
45
  freqs=torch.polar(torch.ones_like(freqs), freqs)
46
 
@@ -53,28 +194,6 @@ class rotary(nn.Module):
53
  x1 = x1.view(orig_shape)
54
  return torch.cat([x1.type_as(x), x2], dim=-1)
55
 
56
- def shape(dims, head, q, k, v):
57
- head_dim = dims // head
58
- scale = head_dim ** -0.25
59
- q = q * scale
60
- k = k * scale
61
- v = v
62
- def _shape(tensor):
63
- return tensor.view(*tensor.shape[:2], head, -1).permute(0, 2, 1, 3).contiguous()
64
- return _shape(q), _shape(k), _shape(v)
65
-
66
- def qkv_init(dims: int, head: int):
67
- head_dim = dims // head
68
- q = nn.Linear(dims, dims)
69
- k = nn.Linear(dims, dims, bias=False)
70
- v = nn.Linear(dims, dims)
71
- o = nn.Linear(dims, dims)
72
- lna = nn.LayerNorm(dims, bias=False)
73
- lnb = nn.LayerNorm(dims, bias=False)
74
- lnc = nn.LayerNorm(head_dim, bias=False)
75
- lnd = nn.LayerNorm(head_dim, bias=False)
76
- return q, k, v, o, lna, lnb, lnc, lnd
77
-
78
  def calculate_attention(q, k, v, mask=None, temp=1.0):
79
  scaled_q = q
80
  if temp != 1.0 and temp > 0:
@@ -82,94 +201,136 @@ def calculate_attention(q, k, v, mask=None, temp=1.0):
82
  out = scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1)
83
  return out
84
 
85
  class LocalOut(nn.Module):
86
  def __init__(self, dims: int, head: int):
87
  super().__init__()
88
- head_dim = dims // head
89
- self.query_module = nn.Linear(head_dim, head_dim)
90
- self.key_module = nn.Linear(head_dim, head_dim)
91
- self.value_module = nn.Linear(head_dim, head_dim)
92
- self.out_proj = nn.Linear(head_dim, head_dim)
 
93
 
94
  def _reshape_to_output(self, attn_output: Tensor) -> Tensor:
95
  batch, _, ctx, _ = attn_output.shape
96
- return attn_output.transpose(1, 2).contiguous().view(batch, ctx, self.dims)
97
 
98
  class attentionb(nn.Module):
99
- def __init__(self, dims: int, head: int, max_iter: int = 3, threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1, temp = 1.0):
 
100
  super(attentionb, self).__init__()
101
- self.q, self.k, self.v, self.o, self.lna, self.lnb, self.lnc, self.lnd = qkv_init(dims, head)
102
- self.dims = dims
103
  self.head = head
104
  self.max_iter = max_iter
105
  self.threshold = nn.Parameter(torch.tensor(threshold))
106
  self.temp = nn.Parameter(torch.tensor(temp), requires_grad=True)
107
  self.factor = nn.Parameter(torch.tensor(factor))
108
- self.alocal = LocalOut(dims, head)
109
-
110
- def _focus(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None):
111
- q = self.q(self.lna(x))
112
- k = self.k(self.lnb(x if xa is None else xa))
113
- v = self.v(self.lnb(x if xa is None else xa))
114
- q, k, v = shape(self.dims, self.head, q, k, v)
115
-
116
- iteration = 0
117
- temp = self.temp.item()
118
- prev_out = torch.zeros_like(q)
119
- attn_out = torch.zeros_like(q)
120
- threshold = self.threshold.item()
121
- factor = self.factor.item()
122
- qcur = q
123
-
124
- while iteration < self.max_iter:
125
- eff_span = min(qcur.shape[1], k.shape[1])
126
- if xa is not None:
127
- eff_span = min(eff_span, xa.shape[1])
128
- if eff_span == 0:
129
- break
130
-
131
- qiter = qcur[:, :, :eff_span, :]
132
- kiter = k[:, :, :eff_span, :]
133
- viter = v[:, :, :eff_span, :]
134
- q = self.alocal.query_module(qiter)
135
- k = self.alocal.key_module(kiter)
136
- v = self.alocal.value_module(viter)
137
-
138
- iter_mask = None
139
- if mask is not None:
140
- if mask.dim() == 4:
141
- iter_mask = mask[:, :, :eff_span, :eff_span]
142
- elif mask.dim() == 2:
143
- iter_mask = mask[:eff_span, :eff_span]
144
-
145
- attn_iter = calculate_attention(
146
- self.lnc(q), self.lnd(k), v,
147
- mask=iter_mask, temp=temp)
148
-
149
- iter_out = torch.zeros_like(qcur)
150
- iter_out[:, :, :eff_span, :] = attn_iter
151
- diff = torch.abs(iter_out - prev_out).mean()
152
- dthresh = threshold + factor * diff
153
- if diff < dthresh and iteration > 0:
154
  attn_out = iter_out
155
- break
156
-
157
- prev_out = iter_out.clone()
158
- qcur = qcur + iter_out
159
- attn_out = iter_out
160
- iteration += 1
161
- temp += 0.005
162
 
163
- output = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2)
164
- return self.o(output), None
165
 
166
- def _slide_win_local(self, x: Tensor, win_size: int, span_len: int, mask: Optional[Tensor] = None) -> Tensor:
 
 
167
 
168
- batch, ctx, dims = x.shape
169
  output = torch.zeros_like(x)
170
- num_win = (ctx + win_size - 1) // win_size
171
 
172
- for i in range(num_win):
173
  qstart = i * win_size
174
  qend = min(qstart + win_size, ctx)
175
  win_qlen = qend - qstart
@@ -188,50 +349,151 @@ class attentionb(nn.Module):
188
  elif mask.dim() == 2:
189
  win_mask = mask[qstart:qend, kstart:kend]
190
 
191
- attn_out, _ = self._focus(x=qwin, xa=kwin, mask=win_mask)
192
  output[:, qstart:qend, :] = attn_out
193
  return output
194
 
195
- def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None,
196
- use_sliding_win: bool = False, win_size: int = 512, span_len: int = 1024) -> Tensor:
197
- if use_sliding_win:
198
- return self._slide_win_local(x, win_size, span_len, mask)
199
  else:
200
- output, _ = self._focus(x, xa, mask)
201
  return output
202
 
203
  class attentiona(nn.Module):
204
- def __init__(self, dims: int, head: int):
205
- super(attentiona, self).__init__()
206
- self.q, self.k, self.v, self.o, self.lna, self.lnb, self.lnc, self.lnd = qkv_init(dims, head)
207
  self.dims = dims
208
  self.head = head
209
- self.rope = rotary(dims=dims, head=head)
210
- def forward(self, x: Tensor, xa = None, mask = None):
211
- q = self.q(self.lna(x))
212
- k = self.k(self.lnb(x if xa is None else xa))
213
- v = self.v(self.lnb(x if xa is None else xa))
214
- q, k, v = shape(self.dims, self.head, q, k, v)
215
- q = self.rope(q, q.shape[2])
216
- k = self.rope(k, k.shape[2])
217
- a = scaled_dot_product_attention(self.lnc(q), self.lnd(k), v, is_causal=mask is not None and q.shape[1] > 1)
218
- out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
219
- return self.o(out)
220
-
221
- class Residual(nn.Module):
222
  def __init__(self, dims: int, head: int, act: str = "silu"):
223
  super().__init__()
224
 
225
- self.lna = nn.LayerNorm(dims, bias=False)
226
- self.attna = attentiona(dims, head)
227
- self.attnb = attentionb(dims, head, max_iter=3)
228
- self.mlp = nn.Sequential(Linear(dims, dims*4), get_activation(act), Linear(dims*4, dims))
229
-
230
- def forward(self, x, xa = None, mask = None) -> Tensor:
231
- x = x + self.attna(self.lna(x), mask=mask)
232
  if xa is not None:
233
- x = x + self.attna(self.lna(x), xa, mask=None)
234
- x = x + self.attnb(self.lna(x), xa, mask=None, use_sliding_win=True, win_size=256, span_len=512)
 
 
235
  x = x + self.mlp(self.lna(x))
236
  return x
237
 
@@ -239,50 +501,48 @@ class processor(nn.Module):
239
  def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
240
  super(processor, self).__init__()
241
 
242
- self.lna = nn.LayerNorm(dims)
243
- self.lnb = nn.LayerNorm(dims)
244
- self.lnc = nn.LayerNorm(dims)
245
- self.token_emb = nn.Embedding(vocab, dims)
246
  self.positions = nn.Parameter(torch.empty(ctx, dims), requires_grad=True)
247
- self.audio_emb = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
248
 
249
  act_fn = get_activation(act)
250
- self.audio_enc = nn.Sequential(
251
- Conv1d(1, dims, kernel_size=3, stride=1, padding=1), act_fn,
252
- Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
253
- Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
254
 
255
- self.bA = nn.ModuleList([Residual(dims, head, act_fn) for _ in range(layer)])
 
256
 
257
  mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
258
  self.register_buffer("mask", mask, persistent=False)
259
 
260
- def forward(self, x, xa, sequential=False, modal=False) -> Tensor:
261
 
262
- x = self.token_emb(x.long()) + self.positions[:x.shape[1]]
263
- xa = self.audio_enc(xa).permute(0, 2, 1)
264
- xa = xa + self.audio_emb(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)
265
 
266
- for b in chain(self.bA or []):
267
- xa = b(x=xa, xa=None, mask=None)
268
- x = b(x=x, xa=None, mask=self.mask)
269
- x = b(x=x, xa=xa, mask=None)
270
- xc = b(torch.cat([x, xa], dim=1), xa=None, mask=self.mask) if modal else None
271
- x = b(x=xc[:, :x.shape[1]], xa=xc[:, x.shape[1]:], mask=None) if modal else x
272
 
273
  x = nn.functional.dropout(x, p=0.001, training=self.training)
274
- x = self.lnc(x)
275
- x = x @ torch.transpose(self.token_emb.weight.to(dtype), 0, 1).float()
276
- return x
277
 
278
- def init_weights(self):
279
- print("Initializing model weights...")
280
- self.apply(self._init_weights)
281
- print("Initialization summary:")
282
- for module_type, count in self.init_counts.items():
283
- if count > 0:
284
- print(f"{module_type}: {count}")
285
-
286
  class Model(nn.Module):
287
  def __init__(self, param: Dimensions):
288
  super().__init__()
@@ -296,14 +556,34 @@ class Model(nn.Module):
296
  layer=param.layer,
297
  act=param.act)
298
 
299
- def forward(self,
300
- labels=None, input_ids=None, pitch: Optional[torch.Tensor]=None) -> Dict[str, Optional[torch.Tensor]]:
301
  x = input_ids
302
- xa = pitch if pitch is not None else torch.zeros(1, 1, self.param.mels, device=device, dtype=dtype)
303
  logits = self.processor(x, xa)
304
  loss = None
305
  if labels is not None:
306
- loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1))
 
307
  return {"logits": logits, "loss": loss}
308
 
309
  def _init_weights(self, module):
@@ -311,26 +591,29 @@ class Model(nn.Module):
311
  "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
312
  "Conv2d": 0, "processor": 0, "attention": 0, "Residual": 0}
313
  for name, module in self.named_modules():
314
- if isinstance(module, RMSNorm):
315
  nn.init.ones_(module.weight)
316
  self.init_counts["RMSNorm"] += 1
317
  elif isinstance(module, nn.Linear):
318
  if module.weight is not None:
319
  nn.init.xavier_uniform_(module.weight)
320
  if module.bias is not None:
321
  nn.init.zeros_(module.bias)
322
  self.init_counts["Linear"] += 1
323
- elif isinstance(module, Conv1d):
324
  nn.init.normal_(module.weight, mean=0.0, std=0.02)
325
  if module.bias is not None:
326
  nn.init.zeros_(module.bias)
327
  self.init_counts["Conv1d"] += 1
328
- elif isinstance(module, Conv2d):
329
  nn.init.normal_(module.weight, mean=0.0, std=0.02)
330
  if module.bias is not None:
331
  nn.init.zeros_(module.bias)
332
  self.init_counts["Conv2d"] += 1
333
- elif isinstance(module, Residual):
334
  self.init_counts["Residual"] += 1
335
  elif isinstance(module, processor):
336
  self.init_counts["processor"] += 1
@@ -342,3 +625,20 @@ class Model(nn.Module):
342
  for module_type, count in self.init_counts.items():
343
  if count > 0:
344
  print(f"{module_type}: {count}")
 
3
  import logging
4
  from itertools import chain
5
  import torch
6
+ import os
+ import torch.nn.functional as F
7
+ from torch import nn, Tensor, einsum
8
+ from typing import Optional, List, Tuple, Union
9
+
10
  import numpy as np
 
11
  from dataclasses import dataclass
12
+
13
  from torch.nn.functional import scaled_dot_product_attention
14
+ from einops import rearrange
15
+
16
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
17
  dtype = torch.float32
18
  warnings.filterwarnings("ignore")
19
  logging.basicConfig(level=logging.ERROR)
20
 
21
+ def sinusoids(ctx, dims, max_tscale=10000):
22
+ assert dims % 2 == 0
23
+ pos = torch.log(torch.tensor(float(max_tscale))) / (dims // 2 - 1)
24
+ tscales = torch.exp(-pos * torch.arange(dims // 2, device=device, dtype=torch.float32))
25
+ scaled = torch.arange(ctx, device=device, dtype=torch.float32).unsqueeze(1) * tscales.unsqueeze(0)
26
+ position = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=1)
27
+ positional_embedding = nn.Parameter(position, requires_grad=True)
28
+ return positional_embedding
29
+
30
+ def valid(default_value, *items):
31
+ for item in items:
32
+ if item is not None:
33
+ return item
34
+ return default_value
35
+
36
+ def dict_to(d, device, dtype=dtype):
37
+ return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
38
+ for k, v in d.items()}
39
+
40
+ def exists(v):
41
+ return v is not None
42
+
43
+ def default(v, b):
44
+ return v if exists(v) else b
45
+
46
+ class Conv1d(nn.Conv1d):
47
+ def _conv_forward(
48
+ self, x: Tensor, weight: Tensor, bias) -> Tensor:
49
+ return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))
50
+
51
+ class Conv2d(nn.Conv2d):
52
+ def _conv_forward(
53
+ self, x: Tensor, weight: Tensor, bias) -> Tensor:
54
+ return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))
55
+
56
+ class Linear(nn.Module):
57
+ def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
58
+ super(Linear, self).__init__()
59
+ self.linear = nn.Linear(in_features, out_features, bias=bias)
60
+ torch.nn.init.xavier_uniform_(self.linear.weight)
61
+ if bias:
62
+ torch.nn.init.zeros_(self.linear.bias)
63
+ def forward(self, x: Tensor) -> Tensor:
64
+ return self.linear(x)
65
+
66
+ class RMSNorm(nn.Module):
67
+ def __init__(self, dims: Union[int, Tensor, List, Tuple],
68
+ eps = 1e-8, elementwise_affine = True):
69
+ super(RMSNorm, self).__init__()
70
+ if isinstance(dims, int):
71
+ self.normalized_shape = (dims,)
72
+ else:
73
+ self.normalized_shape = tuple(dims)
74
+ self.eps = eps
75
+ self.elementwise_affine = elementwise_affine
76
+ if self.elementwise_affine:
77
+ self.weight = nn.Parameter(torch.empty(self.normalized_shape)) # type: ignore
78
+ torch.nn.init.ones_(self.weight)
79
+ else:
80
+ self.register_parameter("weight", None)
81
+ def forward(self, x):
82
+ return F.rms_norm(x, self.normalized_shape, self.weight, self.eps) # type: ignore
83
+
84
+ def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
85
+ weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
86
+ eps: float = 1e-5) -> Tensor:
87
+ return F.layer_norm(x, normalized_shape, weight, bias, eps) # type: ignore
88
+
89
+ def get_device():
90
+ return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
91
+
92
+ def get_dtype():
93
+ return torch.float32 if torch.cuda.is_available() else torch.float64
94
+
95
+ def l2norm(tensor):
96
+ dtype = tensor.dtype
97
+ normed = F.normalize(tensor, dim = -1)
98
+ return normed.type(dtype)
99
+
100
+ def get_activation(act: str) -> nn.Module:
101
+ act_map = {
102
+ "gelu": nn.GELU(),
103
+ "relu": nn.ReLU(),
104
+ "sigmoid": nn.Sigmoid(),
105
+ "tanh": nn.Tanh(),
106
+ "swish": nn.SiLU(),
107
+ "tanhshrink": nn.Tanhshrink(),
108
+ "softplus": nn.Softplus(),
109
+ "softshrink": nn.Softshrink(),
110
+ "leaky_relu": nn.LeakyReLU(),
111
+ "elu": nn.ELU()
112
+ }
113
+ return act_map.get(act, nn.GELU())
114
+
115
+ def there_is_a(val):
116
+ return val is not None
117
+
118
+ def exists(val):
119
+ return val is not None
120
+
121
+ def default(value, d):
122
+ return d if not exists(value) else value
123
+
124
+ def to(t):
125
+ return {'device': t.device, 'dtype': t.dtype}
126
+
127
+ PATH = 'E:/hf'
128
+ os.environ['HF_HOME'] = PATH
129
+ os.environ['HF_DATASETS_CACHE'] = PATH
130
+ os.environ['TORCH_HOME'] = PATH
131
+ os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
132
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
133
+
134
  @dataclass
135
  class Dimensions:
136
  vocab: int
 
141
  layer: int
142
  act: str
143
 
144
+ def qkv_init(dims, head):
145
+ head_dim = dims // head
146
+ q = nn.Linear(dims, dims)
147
+ k = nn.Linear(dims, dims)
148
+ v = nn.Linear(dims, dims)
149
+ o = nn.Linear(dims, dims)
150
+ lna = nn.LayerNorm(dims)
151
+ lnb = nn.LayerNorm(dims)
152
+ lnc = nn.LayerNorm(head_dim)
153
+ lnd = nn.LayerNorm(head_dim)
154
+ return q, k, v, o, lna, lnb, lnc, lnd
155
+
156
+ def shape(dims, head, q, k, v):
157
+ batch_size = q.shape[0]
158
+ seq_len_q = q.shape[1]
159
+ seq_len_kv = k.shape[1]
160
+ head_dim = dims // head
161
+
162
+ q = q.view(batch_size, seq_len_q, head, head_dim).transpose(1, 2)
163
+ k = k.view(batch_size, seq_len_kv, head, head_dim).transpose(1, 2)
164
+ v = v.view(batch_size, seq_len_kv, head, head_dim).transpose(1, 2)
165
+ return q, k, v
166
+
167
  class rotary(nn.Module):
168
  def __init__(self, dims, head):
169
  super(rotary, self).__init__()
170
  self.dims = dims
171
  self.head = head
172
  self.head_dim = dims // head
173
+
174
+ self.theta = nn.Parameter((torch.tensor(10000, device=device, dtype=dtype)), requires_grad=True)
175
  self.register_buffer('freqs_base', self._compute_freqs_base(), persistent=False)
176
 
177
  def _compute_freqs_base(self):
178
  mel_scale = torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 4000/200)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1
179
  return 200 * mel_scale / 1000
180
 
181
+ def forward(self, x) -> Tensor:
182
+ freqs = (self.theta / 220.0) * self.freqs_base
183
+
184
+ pos = torch.arange(x.shape[2], device=device, dtype=dtype)
185
  freqs = pos[:, None] * freqs
186
  freqs=torch.polar(torch.ones_like(freqs), freqs)
187
 
 
194
  x1 = x1.view(orig_shape)
195
  return torch.cat([x1.type_as(x), x2], dim=-1)
196
 
 
197
  def calculate_attention(q, k, v, mask=None, temp=1.0):
198
  scaled_q = q
199
  if temp != 1.0 and temp > 0:
 
201
  out = scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1)
202
  return out
203
 
204
+ # def calculate_attention(q_norm, k_norm, v_iter, mask=None, temp=1.0):
205
+ # d_k = q_norm.size(-1)
206
+ # scores = torch.matmul(q_norm, k_norm.transpose(-2, -1)) / (torch.sqrt(torch.tensor(d_k, dtype=torch.float32)) / temp)
207
+ # if mask is not None:
208
+ # scores = scores.masked_fill(mask == 0, float('-inf'))
209
+ # attention_weights = F.softmax(scores, dim=-1)
210
+ # output = torch.matmul(attention_weights, v_iter)
211
+ # return output
212
+
213
  class LocalOut(nn.Module):
214
  def __init__(self, dims: int, head: int):
215
  super().__init__()
216
+ self.head_dim = dims // head
217
+ self.dims = dims
218
+ self.q_hd = nn.Linear(self.head_dim, self.head_dim)
219
+ self.k_hd = nn.Linear(self.head_dim, self.head_dim)
220
+ self.v_hd = nn.Linear(self.head_dim, self.head_dim)
221
+ self.out = nn.Linear(self.head_dim, self.head_dim)
222
 
223
  def _reshape_to_output(self, attn_output: Tensor) -> Tensor:
224
  batch, _, ctx, _ = attn_output.shape
225
+ return attn_output.transpose(1, 2).contiguous().view(batch, ctx, self.dims)
226
 
227
  class attentionb(nn.Module):
228
+ def __init__(self, dims: int, head: int, max_iter: int = 3,
229
+ threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1, temp = 1.0, use_win=False):
230
  super(attentionb, self).__init__()
231
+
 
232
  self.head = head
233
+ self.dims = dims
234
+ head_dim = dims // head
235
+
236
+ self.que = nn.Linear(dims, dims, bias=False)
237
+ self.kv = nn.Linear(dims, dims * 2, bias=False)
238
+ self.out = nn.Linear(dims, dims, bias=False)
239
+
240
+ self.lna = nn.LayerNorm(dims)
241
+ self.lnb = nn.LayerNorm(head_dim)
242
+ self.rope = rotary(dims, head)
243
+ self.use_win = use_win
244
+
245
  self.max_iter = max_iter
246
  self.threshold = nn.Parameter(torch.tensor(threshold))
247
  self.temp = nn.Parameter(torch.tensor(temp), requires_grad=True)
248
  self.factor = nn.Parameter(torch.tensor(factor))
249
+ self.local = LocalOut(dims, head)
250
+
251
+ def update_win(self, win_size=None):
252
+ if win_size is not None:
253
+ self.win_size = win_size
254
+ return win_size
255
+ elif hasattr(self, 'win_size') and self.win_size is not None:
256
+ win_size = self.win_size
257
+ return win_size
258
+ return None
259
+
260
+ def _focus(self, x, xa = None, mask = None, use_win = False):
261
+
262
+ q = self.que(self.lna(x))
263
+ k, v = self.kv(self.lna(default(xa, x))).chunk(2, dim=-1)
264
+ q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
265
+ _, _, ctx, _ = q.shape
266
+ self.scale = q.shape[-1] ** -0.35
267
+ q = self.rope(q)
268
+ k = self.rope(k)
269
+
270
+ if use_win:
271
+ iteration = 0
272
+ temp = self.temp.item()
273
+ prev_out = torch.zeros_like(q)
274
+ attn_out = torch.zeros_like(q)
275
+ threshold = self.threshold.item()
276
+ factor = self.factor.item()
277
+ curq = q
278
+
279
+ while iteration < self.max_iter:
280
+ eff_span = min(curq.shape[1], k.shape[1])
281
+ if xa is not None:
282
+ eff_span = min(eff_span, xa.shape[1])
283
+ if eff_span == 0:
284
+ break
285
+
286
+ qiter = curq[:, :, :eff_span, :]
287
+ kiter = k[:, :, :eff_span, :]
288
+ viter = v[:, :, :eff_span, :]
289
+ q = self.local.q_hd(qiter)
290
+ k = self.local.k_hd(kiter)
291
+ v = self.local.v_hd(viter)
292
+
293
+ iter_mask = None
294
+ if mask is not None:
295
+ if mask.dim() == 4:
296
+ iter_mask = mask[:, :, :eff_span, :eff_span]
297
+ elif mask.dim() == 2:
298
+ iter_mask = mask[:eff_span, :eff_span]
299
+
300
+ attn_iter = calculate_attention(
301
+ self.lnb(q), self.lnb(k), v,
302
+ mask=iter_mask, temp=temp)
303
+
304
+ iter_out = torch.zeros_like(curq)
305
+ iter_out[:, :, :eff_span, :] = attn_iter
306
+ diff = torch.abs(iter_out - prev_out).mean()
307
+ dthresh = threshold + factor * diff
308
+ if diff < dthresh and iteration > 0:
309
+ attn_out = iter_out
310
+ break
311
+
312
+ prev_out = iter_out.clone()
313
+ curq = curq + iter_out
314
  attn_out = iter_out
315
+ iteration += 1
316
+ temp += 0.005
317
+ else:
318
+ attn_out = scaled_dot_product_attention(self.lnc(q), self.lnd(k), v, is_causal=mask is not None and ctx >1)
319
+ wv = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2)
320
+ qk=None
321
+ return self.out(wv), qk
322
 
323
+ def _slide_win_local(self, x, mask = None, use_win = True) -> Tensor:
 
324
 
325
+ win = self.update_win()
326
+ win_size = win if win is not None else 128
327
+ span_len = win_size + win_size // 10
328
 
329
+ _, ctx, _ = x.shape
330
  output = torch.zeros_like(x)
331
+ windows = (ctx + win_size - 1) // win_size
332
 
333
+ for i in range(windows):
334
  qstart = i * win_size
335
  qend = min(qstart + win_size, ctx)
336
  win_qlen = qend - qstart
 
349
  elif mask.dim() == 2:
350
  win_mask = mask[qstart:qend, kstart:kend]
351
 
352
+ attn_out, _ = self._focus(x=qwin, xa=kwin, mask=win_mask, use_win=True)
353
  output[:, qstart:qend, :] = attn_out
354
  return output
355
 
356
+ def forward(self, x, xa = None, mask = None, use_win: bool = False):
357
+ if use_win:
358
+ return self._slide_win_local(x, mask, use_win=True)
 
359
  else:
360
+ output, _ = self._focus(x, xa, mask, use_win=False)
361
  return output
362
 
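A short annotation of the attentionb block above (my reading of the commit, not text from it):

# _focus(use_win=True) refines attention iteratively: each pass attends over an effective span,
# adds the result back into the query (curq = curq + iter_out), and stops early once successive
# outputs differ by less than an adaptive threshold (threshold + factor * diff) or max_iter is hit,
# nudging temp up by 0.005 per pass. _slide_win_local splits the sequence into win_size-token query
# windows (falling back to 128 when update_win was never called), computes span_len = win_size +
# win_size // 10 for the key window, and runs that loop once per window.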
363
  class attentiona(nn.Module):
364
+ def __init__(self, dims: int, head: int, dropout_rate: float = 0.1, cross_talk=False):
365
+ super().__init__()
366
+ self.head = head
367
  self.dims = dims
368
+ self.cross_talk = cross_talk
369
+
370
+ self.que = nn.Linear(dims, dims, bias=False)
371
+ self.kv = nn.Linear(dims, dims * 2, bias=False)
372
+ self.out = nn.Linear(dims, dims, bias=False)
373
+
374
+ self.ln = nn.LayerNorm(dims)
375
+ self.rope = rotary(dims, head)
376
+
377
+ self.x = nn.Conv2d(head, head, 1, bias = False) if cross_talk else None
378
+ self.xa = nn.Conv2d(head, head, 1, bias = False) if cross_talk else None
379
+
380
+ def forward(self, x, xa = None, mask = None):
381
+
382
+ q = self.que(self.ln(x))
383
+ k, v = self.kv(self.ln(x if xa is None else xa)).chunk(2, dim=-1)
384
+ q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
385
+
386
+ scale = q.shape[-1] ** -0.5
387
+
388
+ q = self.rope(q)
389
+ k = self.rope(k)
390
+
391
+ qk = einsum('b h k d, b h q d -> b h k q', q, k) * scale
392
+
393
+ if there_is_a(mask):
394
+ i, j = qk.shape[-2:]
395
+ mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
396
+ qk = qk.masked_fill(mask, -torch.finfo(qk.dtype).max)
397
+
398
+ qk = torch.nn.functional.softmax(qk, dim=-1)
399
+ wv = einsum('b h k q, b h q d -> b h k d', qk, v)
400
+ wv = rearrange(wv, 'b h c d -> b c (h d)')
401
+ out = self.out(wv)
402
+
403
+ if self.cross_talk:
404
+ qk, v = self.kv(self.ln(x)).chunk(2, dim=-1)
405
+ qka, va = self.kv(self.ln(x if xa is None else xa)).chunk(2, dim=-1)
406
+ qk, qka, v, va = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.head), (qk, qka, v, va))
407
+ qk = einsum('b h i d, b h j d -> b h i j', qk, qka)
408
+ if there_is_a(mask):
409
+ i, j = qk.shape[-2:]
410
+ mask = torch.ones(i, j, device=device, dtype=torch.bool).triu(j - i + 1)
411
+ qk = qk.masked_fill(mask, -torch.finfo(qk.dtype).max)
412
+ x = qk.softmax(dim = -1)
413
+ xa = qk.softmax(dim = -2)
414
+ x = self.x(x)
415
+ xa = self.xa(xa)
416
+ x = einsum('b h i j, b h j d -> b h i d', x, va)
417
+ xa = einsum('b h j i, b h j d -> b h i d', xa, v)
418
+ x, xa = map(lambda t: rearrange(t, 'b h n d -> b n (h d)'), (x, xa))
419
+ out = self.out(x) # outxa = self.out(xa)
420
+
421
+ return out, qk
422
+
423
+ class attentiond(nn.Module):
424
+ def __init__(self, dims: int, head: int):
425
+ super().__init__()
426
  self.head = head
427
+ self.dims = dims
428
+
429
+ self.que = nn.Linear(dims, dims, bias=False)
430
+ self.kv = nn.Linear(dims, dims * 2, bias=False)
431
+ self.out = nn.Linear(dims, dims, bias=False)
432
+
433
+ self.ln = nn.LayerNorm(dims)
434
+ self.rope = rotary(dims, head)
435
+
436
+ self.x = nn.Conv2d(head, head, 1, bias = False)
437
+ self.xa = nn.Conv2d(head, head, 1, bias = False)
438
+
439
+ def forward(self, x, xa = None, mask = None):
440
+
441
+ qk, v = self.kv(self.ln(x)).chunk(2, dim=-1)
442
+ qka, va = self.kv(self.ln(x if xa is None else xa)).chunk(2, dim=-1)
443
+ qk, qka, v, va = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.head), (qk, qka, v, va))
444
+ qk = einsum('b h i d, b h j d -> b h i j', qk, qka)
445
+ if there_is_a(mask):
446
+ i, j = qk.shape[-2:]
447
+ mask = torch.ones(i, j, device=device, dtype=torch.bool).triu(j - i + 1)
448
+ qk = qk.masked_fill(mask, -torch.finfo(qk.dtype).max)
449
+ x = qk.softmax(dim = -1)
450
+ xa = qk.softmax(dim = -2)
451
+ x = self.x(x)
452
+ xa = self.xa(xa)
453
+ x = einsum('b h i j, b h j d -> b h i d', x, va)
454
+ xa = einsum('b h j i, b h j d -> b h i d', xa, v)
455
+ x, xa = map(lambda t: rearrange(t, 'b h n d -> b n (h d)'), (x, xa))
456
+ out = self.out(x)
457
+ outxa = self.out(xa)
458
+
459
+ return out, outxa, qk
460
+
461
+ class tgate(nn.Module):
462
+ def __init__(self, dims, num_types=4):
463
+ super().__init__()
464
+ self.gates = nn.ModuleList([nn.Sequential(nn.Linear(dims, 1), nn.Sigmoid()) for _ in range(num_types)])
465
+ self.classifier = nn.Sequential(nn.Linear(dims, num_types), nn.Softmax(dim=-1))
466
+ def forward(self, x):
467
+ types = self.classifier(x)
468
+ gates = torch.stack([gate(x) for gate in self.gates], dim=-1)
469
+ cgate = torch.sum(gates * types.unsqueeze(2), dim=-1)
470
+ return cgate
471
+
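A quick shape walk-through of tgate (annotation, not part of the commit):

# For x of shape (B, T, dims):
#   types = classifier(x)                           -> (B, T, num_types), softmax over gate types
#   gates = stack([g(x) for g in self.gates], -1)   -> (B, T, 1, num_types), one sigmoid gate each
#   cgate = sum(gates * types.unsqueeze(2), dim=-1) -> (B, T, 1), a per-token scalar gate
# residual.forward adds this scalar gate back onto x, broadcasting over the feature dimension.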
472
+ class residual(nn.Module):
473
  def __init__(self, dims: int, head: int, act: str = "silu"):
474
  super().__init__()
475
 
476
+ self.lna = nn.LayerNorm(dims, bias=False)
477
+ self.atta = attentiona(dims, head)
478
+ # self.attb = attentiona(dims, head)  # shadowed by the attentionb assignment below
479
+ # self.attc = attentiona(dims, head, cross_talk=True)
480
+ self.attb = attentionb(dims, head, max_iter=1, use_win=True, temp=1.0)
481
+ self.tgate = tgate(dims, num_types=4)
482
+
483
+ self.mlp = nn.Sequential(nn.Linear(dims, dims*4), get_activation(act), nn.Linear(dims*4, dims))
484
+
485
+ def forward(
486
+ self,
487
+ x: Tensor,
488
+ xa: Optional[Tensor] = None,
489
+ mask: Optional[Tensor] = None,
490
+ ):
491
+ x = x + self.atta(x, mask=mask)[0]
492
  if xa is not None:
493
+ x = x + self.attb(x, xa, mask=None)
494
+ # x = x + self.attc(x, xa, mask=None)[0]
495
+ # x = x + self.attd(x, xa, mask=None)[0]
496
+ x = x + self.tgate(x)
497
  x = x + self.mlp(self.lna(x))
498
  return x
499
 
 
501
  def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
502
  super(processor, self).__init__()
503
 
504
+ self.ln = nn.LayerNorm(dims)
505
+ self.token = nn.Embedding(vocab, dims)
506
+ self.audio = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
 
507
  self.positions = nn.Parameter(torch.empty(ctx, dims), requires_grad=True)
 
508
 
509
  act_fn = get_activation(act)
510
+ self.encoder = nn.Sequential(
511
+ nn.Conv1d(1, dims, kernel_size=3, stride=1, padding=1), act_fn,
512
+ nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
513
+ nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
514
 
515
+ self.blocka = nn.ModuleList([residual(dims, head, act_fn) for _ in range(layer)])
516
+ self.blockm = nn.ModuleList([residual(dims, head, act_fn) for _ in range(layer // 2)])
517
 
518
  mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
519
  self.register_buffer("mask", mask, persistent=False)
520
 
521
+ def forward(self, x, xa, sequential=False, modal=True, kv_cache=None) -> Tensor:
522
 
523
+ if xa.dim() == 2:
524
+ xa = xa.unsqueeze(0)
 
525
 
526
+ offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
527
+ x = (self.token(x.long()) + self.positions[offset : offset + x.shape[-1]])
528
+
529
+ xa = self.encoder(xa).permute(0, 2, 1)
530
+ xa = xa + self.audio(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)
531
+
532
+ for block in chain(self.blocka or []):
533
+ xa = block(xa, mask=None)
534
+ x = block(x, mask=self.mask)
535
+ x = block(x, xa, mask=None)
536
+
537
+ for block in chain(self.blockm or []):
538
+ xm = block(torch.cat([x, xa], dim=1), torch.cat([x, xa], dim=1), mask=None) if modal else None
539
+ x = block(xm[:, :x.shape[1]], xm[:, x.shape[1]:], mask=None) if modal else x
540
 
541
  x = nn.functional.dropout(x, p=0.001, training=self.training)
542
+ x = self.ln(x)
543
+ x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
544
+ return x
545
 
 
546
  class Model(nn.Module):
547
  def __init__(self, param: Dimensions):
548
  super().__init__()
 
556
  layer=param.layer,
557
  act=param.act)
558
 
559
+ self.best_loss = float('inf')
560
+ self.factor = nn.Parameter(torch.tensor(2), requires_grad=False)
561
+
562
+ def update(self, win_size):
563
+ for name, module in self.processor.named_modules():
564
+ if isinstance(module, (attentionb)):
565
+ module.update_win(win_size)
566
+
567
+ def adjust_window(self, loss, ctx):
568
+ self.win_size = ctx // self.param.head  # adjusted based on the loss, not on the graph itself, which is the idea
569
+ if loss < self.best_loss:
570
+ win_size = (self.win_size * self.factor) #.clamp(0, ctx - 1)
571
+ else:
572
+ win_size = (self.win_size // self.factor).clamp(0, self.win_size - 1)
573
+ self.win_size = win_size
574
+ self.best_loss = loss
575
+ self.update(win_size)
576
+ return win_size
577
+
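A note on the window-adaptation loop above (my reading, not text from the commit):

# Model.forward computes the loss, then adjust_window() multiplies the window by factor when the loss
# beat the previous value and shrinks it (win_size // factor, clamped) otherwise; the base size is
# recomputed from ctx // head and best_loss is overwritten on every call, so the comparison is against
# the last step rather than a true running best. update() then pushes the new size into every
# attentionb via update_win(), where _slide_win_local picks it up on the next forward pass.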
578
+ def forward(self, labels=None, input_ids=None, pitch=None, pitch_tokens=None, spectrogram=None):
579
+
580
  x = input_ids
581
+ xa = pitch if pitch is not None else spectrogram # xb = pitch_tokens
582
  logits = self.processor(x, xa)
583
  loss = None
584
  if labels is not None:
585
+ loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
586
+ self.adjust_window(loss=loss.item(), ctx=xa.shape[2])
587
  return {"logits": logits, "loss": loss}
588
 
589
  def _init_weights(self, module):
 
591
  "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
592
  "Conv2d": 0, "processor": 0, "attention": 0, "Residual": 0}
593
  for name, module in self.named_modules():
594
+ if isinstance(module, nn.RMSNorm):
595
  nn.init.ones_(module.weight)
596
  self.init_counts["RMSNorm"] += 1
597
+ if isinstance(module, nn.LayerNorm):
598
+ nn.init.ones_(module.weight)
599
+ self.init_counts["LayerNorm"] += 1
600
  elif isinstance(module, nn.Linear):
601
  if module.weight is not None:
602
  nn.init.xavier_uniform_(module.weight)
603
  if module.bias is not None:
604
  nn.init.zeros_(module.bias)
605
  self.init_counts["Linear"] += 1
606
+ elif isinstance(module, nn.Conv1d):
607
  nn.init.normal_(module.weight, mean=0.0, std=0.02)
608
  if module.bias is not None:
609
  nn.init.zeros_(module.bias)
610
  self.init_counts["Conv1d"] += 1
611
+ elif isinstance(module, nn.Conv2d):
612
  nn.init.normal_(module.weight, mean=0.0, std=0.02)
613
  if module.bias is not None:
614
  nn.init.zeros_(module.bias)
615
  self.init_counts["Conv2d"] += 1
616
+ elif isinstance(module, residual):
617
  self.init_counts["Residual"] += 1
618
  elif isinstance(module, processor):
619
  self.init_counts["processor"] += 1
 
625
  for module_type, count in self.init_counts.items():
626
  if count > 0:
627
  print(f"{module_type}: {count}")
628
+
629
+ def install_kv_cache_hooks(self, cache: Optional[dict] = None):
630
+ cache = {**cache} if cache is not None else {}
631
+ hooks = []
632
+ def save_to_cache(module, _, output):
633
+ if module not in cache or output.shape[1] > self.param.ctx:
634
+ cache[module] = output
635
+ else:
636
+ cache[module] = torch.cat([cache[module], output], dim=1).detach()
637
+ return cache[module]
638
+
639
+ def install_hooks(layer: nn.Module):
640
+ if isinstance(layer, attentiona):
641
+ hooks.append(layer.kv.register_forward_hook(save_to_cache))  # attentiona fuses k and v into a single kv projection
643
+ self.processor.apply(install_hooks)
644
+ return cache, hooks
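
For orientation, a minimal smoke-test sketch against the updated file (not part of the commit; the Dimensions field names and all values below are illustrative assumptions, since part of the dataclass and Model.__init__ is hidden in the diff):

import torch
from model_simple import Dimensions, Model, device, dtype  # assumes the file is importable as model_simple

param = Dimensions(vocab=10000, mels=128, ctx=512, dims=256, head=4, layer=2, act="gelu")
model = Model(param).to(device)

input_ids = torch.randint(0, param.vocab, (1, 32), device=device)   # (batch, text_ctx) token ids
pitch = torch.randn(1, 1, 256, device=device, dtype=dtype)          # (batch, 1, time) for the Conv1d encoder
out = model(input_ids=input_ids, pitch=pitch)                       # pass labels=... to also get a loss
print(out["logits"].shape)                                          # expected: (1, 32, vocab)

# Note: processor.positions is allocated with torch.empty, so run the init helpers (or initialize it
# manually) before expecting meaningful numbers from this sketch.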