Update model_simple.py

model_simple.py  (+468 −168)
Removed lines (old side of the split diff; entries cut off in this capture end with "…"):

@@ -3,18 +3,134 @@ import warnings
-…
-from …
-from datetime import datetime
-from …

@@ -25,22 +141,47 @@ class Dimensions:
-        …
-    def forward(self, x…
-        freqs = (self.theta / 220.0) * self.freqs_base
-        …

@@ -53,28 +194,6 @@ class rotary(nn.Module):
-def shape(dims, head, q, k, v):
-    head_dim = dims // head
-    scale = head_dim ** -0.25
-    q = q * scale
-    k = k * scale
-    v = v
-    def _shape(tensor):
-        return tensor.view(*tensor.shape[:2], head, -1).permute(0, 2, 1, 3).contiguous()
-    return _shape(q), _shape(k), _shape(v)
-
-def qkv_init(dims: int, head: int):
-    head_dim = dims // head
-    q = nn.Linear(dims, dims)
-    k = nn.Linear(dims, dims, bias=False)
-    v = nn.Linear(dims, dims)
-    o = nn.Linear(dims, dims)
-    lna = nn.LayerNorm(dims, bias=False)
-    lnb = nn.LayerNorm(dims, bias=False)
-    lnc = nn.LayerNorm(head_dim, bias=False)
-    lnd = nn.LayerNorm(head_dim, bias=False)
-    return q, k, v, o, lna, lnb, lnc, lnd

@@ -82,94 +201,136 @@ def calculate_attention(q, k, v, mask=None, temp=1.0):
-        head_dim = dims // head
-        self.…
-        self.…
-        self.…
-        self.…
-        return attn_output.transpose(1, 2).contiguous().view(batch, ctx, self.dims)
-    def __init__(self, dims: int, head: int, max_iter: int = 3,…
-        …
-        self.dims = dims
-        self.…
-    def …
-        if …
-        return self.o(output), None
-        for i in range(…
(most of the old _focus/_slide_win_local body in this hunk was removed; those lines are truncated in this capture)

@@ -188,50 +349,151 @@
-            attn_out, _ = self._focus(x=qwin, xa=kwin, mask=win_mask)
-    def forward(self, x…
-        …
-            return self._slide_win_local(x, win_size, span_len, mask)
-            output, _ = self._focus(x, xa, mask)
-    def __init__(self, dims: int, head: int):
-        super(…
-        self.…
-        self.…
(a further run of removed lines from the old attentiona is truncated in this capture)
-        self.lna = nn.LayerNorm(dims, bias=False)
-        self.…
-        self.…
-        self.…
-            x = x + self.…
-            x = x + self.…

@@ -239,50 +501,48 @@ class processor(nn.Module):
-        self.…
-        self.…
-        self.…
-        self.token_emb = nn.Embedding(vocab, dims)
-        self.audio_emb = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
-        self.…
-            Conv1d(1, dims, kernel_size=3, stride=1, padding=1), act_fn,
-            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
-            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
-        self.…
-    def forward(self, x, xa, sequential=False, modal=…
-        xa = xa + self.audio_emb(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)
-        x = self.…
-        x = x @ torch.transpose(self.…
-        return x
-    def init_weights(self):
-        print("Initializing model weights...")
-        self.apply(self._init_weights)
-        print("Initialization summary:")
-        for module_type, count in self.init_counts.items():
-            if count > 0:
-                print(f"{module_type}: {count}")

@@ -296,14 +556,34 @@ class Model(nn.Module):
-        xa = pitch if pitch is not None else …
-        loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1))

@@ -311,26 +591,29 @@
-            if isinstance(module, RMSNorm):
-            elif isinstance(module, Conv1d):
-            elif isinstance(module, Conv2d):
-            elif isinstance(module, …
Updated file (new side of the diff; lines marked "+" are additions, unmarked lines are unchanged context):

@@ -3,18 +3,134 @@ import warnings
 import logging
 from itertools import chain
 import torch
+import torch.nn.functional as F
+from torch import nn, Tensor, einsum
+from typing import Optional, List, Tuple, Union
+
 import numpy as np
 from dataclasses import dataclass
+
 from torch.nn.functional import scaled_dot_product_attention
+from einops import rearrange
+
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 dtype = torch.float32
 warnings.filterwarnings("ignore")
 logging.basicConfig(level=logging.ERROR)
 
+def sinusoids(ctx, dims, max_tscale=10000):
+    assert dims % 2 == 0
+    pos = torch.log(torch.tensor(float(max_tscale))) / (dims // 2 - 1)
+    tscales = torch.exp(-pos * torch.arange(dims // 2, device=device, dtype=torch.float32))
+    scaled = torch.arange(ctx, device=device, dtype=torch.float32).unsqueeze(1) * tscales.unsqueeze(0)
+    position = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=1)
+    positional_embedding = nn.Parameter(position, requires_grad=True)
+    return positional_embedding
+
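For orientation, a quick illustrative check of the sinusoidal helper above (shapes follow directly from its definition):

emb = sinusoids(ctx=6, dims=8)   # nn.Parameter of shape (6, 8): sin terms in the first half of the feature dim, cos in the second
assert emb.shape == (6, 8)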
+def valid(default_value, *items):
+    for item in items:
+        if item is not None:
+            return item
+    return default_value
+
+def dict_to(d, device, dtype=dtype):
+    return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
+            for k, v in d.items()}
+
+def exists(v):
+    return v is not None
+
+def default(v, b):
+    return v if exists(v) else b
+
+class Conv1d(nn.Conv1d):
+    def _conv_forward(
+        self, x: Tensor, weight: Tensor, bias) -> Tensor:
+        return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))
+
+class Conv2d(nn.Conv2d):
+    def _conv_forward(
+        self, x: Tensor, weight: Tensor, bias) -> Tensor:
+        return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))
+
+class Linear(nn.Module):
+    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
+        super(Linear, self).__init__()
+        self.linear = nn.Linear(in_features, out_features, bias=bias)
+        torch.nn.init.xavier_uniform_(self.linear.weight)
+        if bias:
+            torch.nn.init.zeros_(self.linear.bias)
+    def forward(self, x: Tensor) -> Tensor:
+        return self.linear(x)
+
+class RMSNorm(nn.Module):
+    def __init__(self, dims: Union[int, Tensor, List, Tuple],
+                 eps = 1e-8, elementwise_affine = True):
+        super(RMSNorm, self).__init__()
+        if isinstance(dims, int):
+            self.normalized_shape = (dims,)
+        else:
+            self.normalized_shape = tuple(dims)
+        self.eps = eps
+        self.elementwise_affine = elementwise_affine
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(torch.empty(self.normalized_shape))  # type: ignore
+            torch.nn.init.ones_(self.weight)
+        else:
+            self.register_parameter("weight", None)
+    def forward(self, x):
+        return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)  # type: ignore
+
+def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
+              weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
+              eps: float = 1e-5) -> Tensor:
+    return F.layer_norm(x, normalized_shape, weight, bias, eps)  # type: ignore
+
+def get_device():
+    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+def get_dtype():
+    return torch.float32 if torch.cuda.is_available() else torch.float64
+
+def l2norm(tensor):
+    dtype = tensor.dtype
+    normed = F.normalize(tensor, dim=-1)
+    return normed.type(dtype)
+
+def get_activation(act: str) -> nn.Module:
+    act_map = {
+        "gelu": nn.GELU(),
+        "relu": nn.ReLU(),
+        "sigmoid": nn.Sigmoid(),
+        "tanh": nn.Tanh(),
+        "swish": nn.SiLU(),
+        "tanhshrink": nn.Tanhshrink(),
+        "softplus": nn.Softplus(),
+        "softshrink": nn.Softshrink(),
+        "leaky_relu": nn.LeakyReLU(),
+        "elu": nn.ELU()
+        }
+    return act_map.get(act, nn.GELU())
+
+def there_is_a(val):
+    return val is not None
+
+def exists(val):
+    return val is not None
+
+def default(value, d):
+    return d if not exists(value) else value
+
+def to(t):
+    return {'device': t.device, 'dtype': t.dtype}
+
+PATH = 'E:/hf'
+os.environ['HF_HOME'] = PATH
+os.environ['HF_DATASETS_CACHE'] = PATH
+os.environ['TORCH_HOME'] = PATH
+os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
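Two small observations on the helpers above: exists() and default() are defined twice with the same behavior (the later definitions win at import time), and the cache-path block assumes `import os` at the top of the file, which lies outside this hunk. A brief illustrative use of the normalization and activation helpers:

x = torch.randn(2, 4, 8)
norm = RMSNorm(8)               # RMS-normalizes over the last dimension
act = get_activation("swish")   # unknown keys fall back to nn.GELU()
y = act(norm(x))
assert y.shape == x.shape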
 @dataclass
 class Dimensions:
     vocab: int
@@ -25,22 +141,47 @@ class Dimensions:
     layer: int
     act: str
 
+def qkv_init(dims, head):
+    head_dim = dims // head
+    q = nn.Linear(dims, dims)
+    k = nn.Linear(dims, dims)
+    v = nn.Linear(dims, dims)
+    o = nn.Linear(dims, dims)
+    lna = nn.LayerNorm(dims)
+    lnb = nn.LayerNorm(dims)
+    lnc = nn.LayerNorm(head_dim)
+    lnd = nn.LayerNorm(head_dim)
+    return q, k, v, o, lna, lnb, lnc, lnd
+
+def shape(dims, head, q, k, v):
+    batch_size = q.shape[0]
+    seq_len_q = q.shape[1]
+    seq_len_kv = k.shape[1]
+    head_dim = dims // head
+
+    q = q.view(batch_size, seq_len_q, head, head_dim).transpose(1, 2)
+    k = k.view(batch_size, seq_len_kv, head, head_dim).transpose(1, 2)
+    v = v.view(batch_size, seq_len_kv, head, head_dim).transpose(1, 2)
+    return q, k, v
+
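An illustrative use of the two projection helpers above on dummy tensors (dims=16 and head=4 are arbitrary):

q_proj, k_proj, v_proj, o_proj, *_ = qkv_init(dims=16, head=4)
x = torch.randn(2, 10, 16)                      # (batch, ctx, dims)
q, k, v = shape(16, 4, q_proj(x), k_proj(x), v_proj(x))
assert q.shape == (2, 4, 10, 4)                 # (batch, head, ctx, head_dim)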
 class rotary(nn.Module):
     def __init__(self, dims, head):
         super(rotary, self).__init__()
         self.dims = dims
         self.head = head
         self.head_dim = dims // head
+
+        self.theta = nn.Parameter((torch.tensor(10000, device=device, dtype=dtype)), requires_grad=True)
         self.register_buffer('freqs_base', self._compute_freqs_base(), persistent=False)
 
     def _compute_freqs_base(self):
         mel_scale = torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 4000/200)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1
         return 200 * mel_scale / 1000
 
+    def forward(self, x) -> Tensor:
+        freqs = (self.theta / 220.0) * self.freqs_base
+
+        pos = torch.arange(x.shape[2], device=device, dtype=dtype)
         freqs = pos[:, None] * freqs
         freqs=torch.polar(torch.ones_like(freqs), freqs)
 
@@ -53,28 +194,6 @@ class rotary(nn.Module):
         x1 = x1.view(orig_shape)
         return torch.cat([x1.type_as(x), x2], dim=-1)
 
 def calculate_attention(q, k, v, mask=None, temp=1.0):
     scaled_q = q
     if temp != 1.0 and temp > 0:
@@ -82,94 +201,136 @@ def calculate_attention(q, k, v, mask=None, temp=1.0):
     out = scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1)
     return out
 
+# def calculate_attention(q_norm, k_norm, v_iter, mask=None, temp=1.0):
+#     d_k = q_norm.size(-1)
+#     scores = torch.matmul(q_norm, k_norm.transpose(-2, -1)) / (torch.sqrt(torch.tensor(d_k, dtype=torch.float32)) / temp)
+#     if mask is not None:
+#         scores = scores.masked_fill(mask == 0, float('-inf'))
+#     attention_weights = F.softmax(scores, dim=-1)
+#     output = torch.matmul(attention_weights, v_iter)
+#     return output
+
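A minimal illustration of the active calculate_attention above with head-shaped random tensors (batch=2, head=4, ctx=10, head_dim=8 are arbitrary); temp rescales the query before scaled_dot_product_attention:

q = torch.randn(2, 4, 10, 8)
k = torch.randn(2, 4, 10, 8)
v = torch.randn(2, 4, 10, 8)
out = calculate_attention(q, k, v, mask=None, temp=0.5)
assert out.shape == q.shape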
 class LocalOut(nn.Module):
     def __init__(self, dims: int, head: int):
         super().__init__()
+        self.head_dim = dims // head
+        self.dims = dims
+        self.q_hd = nn.Linear(self.head_dim, self.head_dim)
+        self.k_hd = nn.Linear(self.head_dim, self.head_dim)
+        self.v_hd = nn.Linear(self.head_dim, self.head_dim)
+        self.out = nn.Linear(self.head_dim, self.head_dim)
 
     def _reshape_to_output(self, attn_output: Tensor) -> Tensor:
         batch, _, ctx, _ = attn_output.shape
+        return attn_output.transpose(1, 2).contiguous().view(batch, ctx, self.dims)
 
class attentionb(nn.Module):
|
228 |
+
def __init__(self, dims: int, head: int, max_iter: int = 3,
|
229 |
+
threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1, temp = 1.0, use_win=False):
|
230 |
super(attentionb, self).__init__()
|
231 |
+
|
|
|
232 |
self.head = head
|
233 |
+
self.dims = dims
|
234 |
+
head_dim = dims // head
|
235 |
+
|
236 |
+
self.que = nn.Linear(dims, dims, bias=False)
|
237 |
+
self.kv = nn.Linear(dims, dims * 2, bias=False)
|
238 |
+
self.out = nn.Linear(dims, dims, bias=False)
|
239 |
+
|
240 |
+
self.lna = nn.LayerNorm(dims)
|
241 |
+
self.lnb = nn.LayerNorm(head_dim)
|
242 |
+
self.rope = rotary(dims, head)
|
243 |
+
self.use_win = use_win
|
244 |
+
|
245 |
self.max_iter = max_iter
|
246 |
self.threshold = nn.Parameter(torch.tensor(threshold))
|
247 |
self.temp = nn.Parameter(torch.tensor(temp), requires_grad=True)
|
248 |
self.factor = nn.Parameter(torch.tensor(factor))
|
249 |
+
self.local = LocalOut(dims, head)
|
250 |
+
|
251 |
+
def update_win(self, win_size=None):
|
252 |
+
if win_size is not None:
|
253 |
+
self.win_size = win_size
|
254 |
+
return win_size
|
255 |
+
elif hasattr(self, 'win_size') and self.win_size is not None:
|
256 |
+
win_size = self.win_size
|
257 |
+
return win_size
|
258 |
+
return None
|
259 |
+
|
260 |
+
def _focus(self, x, xa = None, mask = None, use_win = False):
|
261 |
+
|
262 |
+
q = self.que(self.lna(x))
|
263 |
+
k, v = self.kv(self.lna(default(xa, x))).chunk(2, dim=-1)
|
264 |
+
q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
|
265 |
+
_, _, ctx, _ = q.shape
|
266 |
+
self.scale = q.shape[-1] ** -0.35
|
267 |
+
q = self.rope(q)
|
268 |
+
k = self.rope(k)
|
269 |
+
|
270 |
+
if use_win:
|
271 |
+
iteration = 0
|
272 |
+
temp = self.temp.item()
|
273 |
+
prev_out = torch.zeros_like(q)
|
274 |
+
attn_out = torch.zeros_like(q)
|
275 |
+
threshold = self.threshold.item()
|
276 |
+
factor = self.factor.item()
|
277 |
+
curq = q
|
278 |
+
|
279 |
+
while iteration < self.max_iter:
|
280 |
+
eff_span = min(curq.shape[1], k.shape[1])
|
281 |
+
if xa is not None:
|
282 |
+
eff_span = min(eff_span, xa.shape[1])
|
283 |
+
if eff_span == 0:
|
284 |
+
break
|
285 |
+
|
286 |
+
qiter = curq[:, :, :eff_span, :]
|
287 |
+
kiter = k[:, :, :eff_span, :]
|
288 |
+
viter = v[:, :, :eff_span, :]
|
289 |
+
q = self.local.q_hd(qiter)
|
290 |
+
k = self.local.k_hd(kiter)
|
291 |
+
v = self.local.v_hd(viter)
|
292 |
+
|
293 |
+
iter_mask = None
|
294 |
+
if mask is not None:
|
295 |
+
if mask.dim() == 4:
|
296 |
+
iter_mask = mask[:, :, :eff_span, :eff_span]
|
297 |
+
elif mask.dim() == 2:
|
298 |
+
iter_mask = mask[:eff_span, :eff_span]
|
299 |
+
|
300 |
+
attn_iter = calculate_attention(
|
301 |
+
self.lnb(q), self.lnb(k), v,
|
302 |
+
mask=iter_mask, temp=temp)
|
303 |
+
|
304 |
+
iter_out = torch.zeros_like(curq)
|
305 |
+
iter_out[:, :, :eff_span, :] = attn_iter
|
306 |
+
diff = torch.abs(iter_out - prev_out).mean()
|
307 |
+
dthresh = threshold + factor * diff
|
308 |
+
if diff < dthresh and iteration > 0:
|
309 |
+
attn_out = iter_out
|
310 |
+
break
|
311 |
+
|
312 |
+
prev_out = iter_out.clone()
|
313 |
+
curq = curq + iter_out
|
314 |
attn_out = iter_out
|
315 |
+
iteration += 1
|
316 |
+
temp += 0.005
|
317 |
+
else:
|
318 |
+
attn_out = scaled_dot_product_attention(self.lnc(q), self.lnd(k), v, is_causal=mask is not None and ctx >1)
|
319 |
+
wv = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2)
|
320 |
+
qk=None
|
321 |
+
return self.out(wv), qk
|
322 |
|
323 |
+
def _slide_win_local(self, x, mask = None, use_win = True) -> Tensor:
|
|
|
324 |
|
325 |
+
win = self.update_win()
|
326 |
+
win_size = win if win is not None else 128
|
327 |
+
span_len = win_size + win_size // 10
|
328 |
|
329 |
+
_, ctx, _ = x.shape
|
330 |
output = torch.zeros_like(x)
|
331 |
+
windows = (ctx + win_size - 1) // win_size
|
332 |
|
333 |
+
for i in range(windows):
|
334 |
qstart = i * win_size
|
335 |
qend = min(qstart + win_size, ctx)
|
336 |
win_qlen = qend - qstart
|
|
|
349 |
elif mask.dim() == 2:
|
350 |
win_mask = mask[qstart:qend, kstart:kend]
|
351 |
|
352 |
+
attn_out, _ = self._focus(x=qwin, xa=kwin, mask=win_mask, use_win=True)
|
353 |
output[:, qstart:qend, :] = attn_out
|
354 |
return output
|
355 |
|
356 |
+
def forward(self, x, xa = None, mask = None, use_win: bool = False):
|
357 |
+
if use_win:
|
358 |
+
return self._slide_win_local(x, mask, use_win=True)
|
|
|
359 |
else:
|
360 |
+
output, _ = self._focus(x, xa, mask, use_win=False)
|
361 |
return output
|
362 |
|
363 |
class attentiona(nn.Module):
|
364 |
+
def __init__(self, dims: int, head: int, dropout_rate: float = 0.1, cross_talk=False):
|
365 |
+
super().__init__()
|
366 |
+
self.head = head
|
367 |
self.dims = dims
|
368 |
+
self.cross_talk = cross_talk
|
369 |
+
|
370 |
+
self.que = nn.Linear(dims, dims, bias=False)
|
371 |
+
self.kv = nn.Linear(dims, dims * 2, bias=False)
|
372 |
+
self.out = nn.Linear(dims, dims, bias=False)
|
373 |
+
|
374 |
+
self.ln = nn.LayerNorm(dims)
|
375 |
+
self.rope = rotary(dims, head)
|
376 |
+
|
377 |
+
self.x = nn.Conv2d(head, head, 1, bias = False) if cross_talk else None
|
378 |
+
self.xa = nn.Conv2d(head, head, 1, bias = False) if cross_talk else None
|
379 |
+
|
380 |
+
def forward(self, x, xa = None, mask = None):
|
381 |
+
|
382 |
+
q = self.que(self.ln(x))
|
383 |
+
k, v = self.kv(self.ln(x if xa is None else xa)).chunk(2, dim=-1)
|
384 |
+
q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
|
385 |
+
|
386 |
+
scale = q.shape[-1] ** -0.5
|
387 |
+
|
388 |
+
q = self.rope(q)
|
389 |
+
k = self.rope(k)
|
390 |
+
|
391 |
+
qk = einsum('b h k d, b h q d -> b h k q', q, k) * scale
|
392 |
+
|
393 |
+
if there_is_a(mask):
|
394 |
+
i, j = qk.shape[-2:]
|
395 |
+
mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
|
396 |
+
qk = qk.masked_fill(mask, -torch.finfo(qk.dtype).max)
|
397 |
+
|
398 |
+
qk = torch.nn.functional.softmax(qk, dim=-1)
|
399 |
+
wv = einsum('b h k q, b h q d -> b h k d', qk, v)
|
400 |
+
wv = rearrange(wv, 'b h c d -> b c (h d)')
|
401 |
+
out = self.out(wv)
|
402 |
+
|
403 |
+
if self.cross_talk:
|
404 |
+
qk, v = self.kv(self.ln(x)).chunk(2, dim=-1)
|
405 |
+
qka, va = self.kv(self.ln(x if xa is None else xa)).chunk(2, dim=-1)
|
406 |
+
qk, qka, v, va = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.head), (qk, qka, v, va))
|
407 |
+
qk = einsum('b h i d, b h j d -> b h i j', qk, qka)
|
408 |
+
if there_is_a(mask):
|
409 |
+
i, j = qk.shape[-2:]
|
410 |
+
mask = torch.ones(i, j, device=device, dtype=torch.bool).triu(j - i + 1)
|
411 |
+
qk = qk.masked_fill(mask, -torch.finfo(qk.dtype).max)
|
412 |
+
x = qk.softmax(dim = -1)
|
413 |
+
xa = qk.softmax(dim = -2)
|
414 |
+
x = self.x(x)
|
415 |
+
xa = self.xa(xa)
|
416 |
+
x = einsum('b h i j, b h j d -> b h i d', x, va)
|
417 |
+
xa = einsum('b h j i, b h j d -> b h i d', xa, v)
|
418 |
+
x, xa = map(lambda t: rearrange(t, 'b h n d -> b n (h d)'), (x, xa))
|
419 |
+
out = self.out(x) # outxa = self.out(xa)
|
420 |
+
|
421 |
+
return out, qk
|
422 |
+
|
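An illustrative self-attention call of attentiona (dims=32, head=4 arbitrary); any non-None mask argument makes forward build its own causal mask, and the second return value is the post-softmax attention map:

attn = attentiona(dims=32, head=4).to(device)
x = torch.randn(2, 10, 32, device=device)
out, qk = attn(x, mask=True)
# out: (2, 10, 32), qk: (2, 4, 10, 10)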
+class attentiond(nn.Module):
+    def __init__(self, dims: int, head: int):
+        super().__init__()
         self.head = head
+        self.dims = dims
+
+        self.que = nn.Linear(dims, dims, bias=False)
+        self.kv = nn.Linear(dims, dims * 2, bias=False)
+        self.out = nn.Linear(dims, dims, bias=False)
+
+        self.ln = nn.LayerNorm(dims)
+        self.rope = rotary(dims, head)
+
+        self.x = nn.Conv2d(head, head, 1, bias=False)
+        self.xa = nn.Conv2d(head, head, 1, bias=False)
+
+    def forward(self, x, xa = None, mask = None):
+
+        qk, v = self.kv(self.ln(x)).chunk(2, dim=-1)
+        qka, va = self.kv(self.ln(x if xa is None else xa)).chunk(2, dim=-1)
+        qk, qka, v, va = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.head), (qk, qka, v, va))
+        qk = einsum('b h i d, b h j d -> b h i j', qk, qka)
+        if there_is_a(mask):
+            i, j = qk.shape[-2:]
+            mask = torch.ones(i, j, device=device, dtype=torch.bool).triu(j - i + 1)
+            qk = qk.masked_fill(mask, -torch.finfo(qk.dtype).max)
+        x = qk.softmax(dim=-1)
+        xa = qk.softmax(dim=-2)
+        x = self.x(x)
+        xa = self.xa(xa)
+        x = einsum('b h i j, b h j d -> b h i d', x, va)
+        xa = einsum('b h j i, b h j d -> b h i d', xa, v)
+        x, xa = map(lambda t: rearrange(t, 'b h n d -> b n (h d)'), (x, xa))
+        out = self.out(x)
+        outxa = self.out(xa)
+
+        return out, outxa, qk
+
+class tgate(nn.Module):
+    def __init__(self, dims, num_types=4):
+        super().__init__()
+        self.gates = nn.ModuleList([nn.Sequential(nn.Linear(dims, 1), nn.Sigmoid()) for _ in range(num_types)])
+        self.classifier = nn.Sequential(nn.Linear(dims, num_types), nn.Softmax(dim=-1))
+    def forward(self, x):
+        types = self.classifier(x)
+        gates = torch.stack([gate(x) for gate in self.gates], dim=-1)
+        cgate = torch.sum(gates * types.unsqueeze(2), dim=-1)
+        return cgate
+
|
472 |
+
class residual(nn.Module):
|
473 |
def __init__(self, dims: int, head: int, act: str = "silu"):
|
474 |
super().__init__()
|
475 |
|
476 |
+
self.lna = nn.LayerNorm(dims, bias=False)
|
477 |
+
self.atta = attentiona(dims, head)
|
478 |
+
self.attb = attentiona(dims, head)
|
479 |
+
# self.attc = attentiona(dims, head, cross_talk=True)
|
480 |
+
self.attb = attentionb(dims, head, max_iter=1, use_win=True, temp=1.0)
|
481 |
+
self.tgate = tgate(dims, num_types=4)
|
482 |
+
|
483 |
+
self.mlp = nn.Sequential(nn.Linear(dims, dims*4), get_activation(act), nn.Linear(dims*4, dims))
|
484 |
+
|
485 |
+
def forward(
|
486 |
+
self,
|
487 |
+
x: Tensor,
|
488 |
+
xa: Optional[Tensor] = None,
|
489 |
+
mask: Optional[Tensor] = None,
|
490 |
+
):
|
491 |
+
x = x + self.atta(x, mask=mask)[0]
|
492 |
if xa is not None:
|
493 |
+
x = x + self.attb(x, xa, mask=None)[0]
|
494 |
+
# x = x + self.attc(x, xa, mask=None)[0]
|
495 |
+
# x = x + self.attd(x, xa, mask=None)[0]
|
496 |
+
x = x + self.tgate(x)
|
497 |
x = x + self.mlp(self.lna(x))
|
498 |
return x
|
499 |
|
|
|
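An illustrative pass through the residual block above (dims=32, head=4 arbitrary); passing xa exercises the cross-attention branch:

block = residual(dims=32, head=4, act="silu").to(device)
x = torch.randn(2, 10, 32, device=device)
xa = torch.randn(2, 15, 32, device=device)
y = block(x, xa)     # (2, 10, 32)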
@@ -239,50 +501,48 @@ class processor(nn.Module):
     def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
         super(processor, self).__init__()
 
+        self.ln = nn.LayerNorm(dims)
+        self.token = nn.Embedding(vocab, dims)
+        self.audio = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
         self.positions = nn.Parameter(torch.empty(ctx, dims), requires_grad=True)
 
         act_fn = get_activation(act)
+        self.encoder = nn.Sequential(
+            nn.Conv1d(1, dims, kernel_size=3, stride=1, padding=1), act_fn,
+            nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
+            nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
 
+        self.blocka = nn.ModuleList([residual(dims, head, act) for _ in range(layer)])  # residual expects the activation name and calls get_activation itself
+        self.blockm = nn.ModuleList([residual(dims, head, act) for _ in range(layer // 2)])
 
         mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)
 
+    def forward(self, x, xa, sequential=False, modal=True, kv_cache=None) -> Tensor:
 
+        if xa.dim() == 2:
+            xa = xa.unsqueeze(0)
 
+        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+        x = (self.token(x.long()) + self.positions[offset : offset + x.shape[-1]])
+
+        xa = self.encoder(xa).permute(0, 2, 1)
+        xa = xa + self.audio(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)
+
+        for block in chain(self.blocka or []):
+            xa = block(xa, mask=None)
+            x = block(x, mask=self.mask)
+            x = block(x, xa, mask=None)
+
+        for block in chain(self.blockm or []):
+            xm = block(torch.cat([x, xa], dim=1), torch.cat([x, xa], dim=1), mask=None) if modal else None
+            x = block(xm[:, :x.shape[1]], xm[:, x.shape[1]:], mask=None) if modal else x
 
         x = nn.functional.dropout(x, p=0.001, training=self.training)
+        x = self.ln(x)
+        x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
+        return x
 
 class Model(nn.Module):
     def __init__(self, param: Dimensions):
         super().__init__()
@@ -296,14 +556,34 @@ class Model(nn.Module):
             layer=param.layer,
             act=param.act)
 
+        self.best_loss = float('inf')
+        self.factor = nn.Parameter(torch.tensor(2), requires_grad=False)
+
+    def update(self, win_size):
+        for name, module in self.processor.named_modules():
+            if isinstance(module, (attentionb)):
+                module.update_win(win_size)
+
+    def adjust_window(self, loss, ctx):
+        self.win_size = (ctx // self.param.head)  # based on loss but not on the graph itself, which is the idea
+        if loss < self.best_loss:
+            win_size = (self.win_size * self.factor)  # .clamp(0, ctx - 1)
+        else:
+            win_size = (self.win_size // self.factor).clamp(0, self.win_size - 1)
+        self.win_size = win_size
+        self.best_loss = loss
+        self.update(win_size)
+        return win_size
+
+    def forward(self, labels=None, input_ids=None, pitch=None, pitch_tokens=None, spectrogram=None):
+
        x = input_ids
+        xa = pitch if pitch is not None else spectrogram  # xb = pitch_tokens
        logits = self.processor(x, xa)
        loss = None
        if labels is not None:
+            loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
+            self.adjust_window(loss=loss.item(), ctx=xa.shape[2])
        return {"logits": logits, "loss": loss}
 
def _init_weights(self, module):
|
|
|
591 |
"Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
|
592 |
"Conv2d": 0, "processor": 0, "attention": 0, "Residual": 0}
|
593 |
for name, module in self.named_modules():
|
594 |
+
if isinstance(module, nn.RMSNorm):
|
595 |
nn.init.ones_(module.weight)
|
596 |
self.init_counts["RMSNorm"] += 1
|
597 |
+
if isinstance(module, nn.LayerNorm):
|
598 |
+
nn.init.ones_(module.weight)
|
599 |
+
self.init_counts["LayerNorm"] += 1
|
600 |
elif isinstance(module, nn.Linear):
|
601 |
if module.weight is not None:
|
602 |
nn.init.xavier_uniform_(module.weight)
|
603 |
if module.bias is not None:
|
604 |
nn.init.zeros_(module.bias)
|
605 |
self.init_counts["Linear"] += 1
|
606 |
+
elif isinstance(module, nn.Conv1d):
|
607 |
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
608 |
if module.bias is not None:
|
609 |
nn.init.zeros_(module.bias)
|
610 |
self.init_counts["Conv1d"] += 1
|
611 |
+
elif isinstance(module, nn.Conv2d):
|
612 |
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
613 |
if module.bias is not None:
|
614 |
nn.init.zeros_(module.bias)
|
615 |
self.init_counts["Conv2d"] += 1
|
616 |
+
elif isinstance(module, residual):
|
617 |
self.init_counts["Residual"] += 1
|
618 |
elif isinstance(module, processor):
|
619 |
self.init_counts["processor"] += 1
|
|
|
@@ -342,3 +625,20 @@
        for module_type, count in self.init_counts.items():
            if count > 0:
                print(f"{module_type}: {count}")
+
+    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
+        cache = {**cache} if cache is not None else {}
+        hooks = []
+        def save_to_cache(module, _, output):
+            if module not in cache or output.shape[1] > self.param.ctx:
+                cache[module] = output
+            else:
+                cache[module] = torch.cat([cache[module], output], dim=1).detach()
+            return cache[module]
+
+        def install_hooks(layer: nn.Module):
+            if isinstance(layer, attentiona):
+                hooks.append(layer.kv.register_forward_hook(save_to_cache))  # attentiona fuses the key/value projections into a single kv layer
+        self.processor.apply(install_hooks)
+        return cache, hooks
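Finally, a minimal end-to-end sketch of how the model appears to be driven; the Dimensions field names not visible in this diff (between vocab and layer) are assumed to be mels, ctx, dims, and head, matching the processor signature:

param = Dimensions(vocab=100, mels=80, ctx=64, dims=64, head=4, layer=2, act="gelu")
model = Model(param).to(device)
input_ids = torch.randint(0, 100, (2, 12), device=device)
pitch = torch.randn(2, 1, 50, device=device)    # single-channel input for the Conv1d encoder
out = model(input_ids=input_ids, pitch=pitch, labels=input_ids)
print(out["logits"].shape, out["loss"])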