Update model_simple.py

model_simple.py  CHANGED  (+89 -198)
@@ -1,15 +1,12 @@
-
+
 import warnings
 import logging
 from itertools import chain
 import torch
-import torch.nn.functional as F
 from torch import nn, Tensor, einsum
-from typing import Optional
-
+from typing import Optional
 import numpy as np
 from dataclasses import dataclass
-
 from torch.nn.functional import scaled_dot_product_attention
 from einops import rearrange
 
@@ -27,71 +24,6 @@ def sinusoids(ctx, dims, max_tscale=10000):
     positional_embedding = nn.Parameter(position, requires_grad=True)
     return positional_embedding
 
-def valid(default_value, *items):
-    for item in items:
-        if item is not None:
-            return item
-    return default_value
-
-def dict_to(d, device, dtype=dtype):
-    return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
-    for k, v in d.items()}
-
-
-class Conv1d(nn.Conv1d):
-    def _conv_forward(
-        self, x: Tensor, weight: Tensor, bias) -> Tensor:
-        return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))
-
-class Conv2d(nn.Conv2d):
-    def _conv_forward(
-        self, x: Tensor, weight: Tensor, bias) -> Tensor:
-        return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))
-
-class Linear(nn.Module):
-    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
-        super(Linear, self).__init__()
-        self.linear = nn.Linear(in_features, out_features, bias=bias)
-        torch.nn.init.xavier_uniform_(self.linear.weight)
-        if bias:
-            torch.nn.init.zeros_(self.linear.bias)
-    def forward(self, x: Tensor) -> Tensor:
-        return self.linear(x)
-
-class RMSNorm(nn.Module):
-    def __init__(self, dims: Union[int, Tensor, List, Tuple],
-            eps = 1e-8, elementwise_affine = True):
-        super(RMSNorm, self).__init__()
-        if isinstance(dims, int):
-            self.normalized_shape = (dims,)
-        else:
-            self.normalized_shape = tuple(dims)
-        self.eps = eps
-        self.elementwise_affine = elementwise_affine
-        if self.elementwise_affine:
-            self.weight = nn.Parameter(torch.empty(self.normalized_shape))  # type: ignore
-            torch.nn.init.ones_(self.weight)
-        else:
-            self.register_parameter("weight", None)
-    def forward(self, x):
-        return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)  # type: ignore
-
-def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
-    weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
-    eps: float = 1e-5) -> Tensor:
-    return F.layer_norm(x, normalized_shape, weight, bias, eps)  # type: ignore
-
-def get_device():
-    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-def get_dtype():
-    return torch.float32 if torch.cuda.is_available() else torch.float64
-
-def l2norm(tensor):
-    dtype = tensor.dtype
-    normed = F.normalize(tensor, dim = -1)
-    return normed.type(dtype)
-
 def get_activation(act: str) -> nn.Module:
     act_map = {
         "gelu": nn.GELU(),
@@ -110,16 +42,6 @@ def get_activation(act: str) -> nn.Module:
 def there_is_a(val):
     return val is not None
 
-def to(t):
-    return {'device': t.device, 'dtype': t.dtype}
-
-PATH = 'E:/hf'
-os.environ['HF_HOME'] = PATH
-os.environ['HF_DATASETS_CACHE'] = PATH
-os.environ['TORCH_HOME'] = PATH
-os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-
 @dataclass
 class Dimensions:
     vocab: int
@@ -187,18 +109,10 @@ def calculate_attention(q, k, v, mask=None, temp=1.0):
     scaled_q = q
     if temp != 1.0 and temp > 0:
         scaled_q = q * (1.0 / temp)**.5
+    print(temp)
     out = scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1)
     return out
 
-# def calculate_attention(q_norm, k_norm, v_iter, mask=None, temp=1.0):
-#     d_k = q_norm.size(-1)
-#     scores = torch.matmul(q_norm, k_norm.transpose(-2, -1)) / (torch.sqrt(torch.tensor(d_k, dtype=torch.float32)) / temp)
-#     if mask is not None:
-#         scores = scores.masked_fill(mask == 0, float('-inf'))
-#     attention_weights = F.softmax(scores, dim=-1)
-#     output = torch.matmul(attention_weights, v_iter)
-#     return output
-
 class LocalOut(nn.Module):
     def __init__(self, dims: int, head: int):
         super().__init__()
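Not part of the diff: a minimal, self-contained sketch of what the temperature handling in calculate_attention amounts to. Pre-scaling q by temp**-0.5 before scaled_dot_product_attention is the same as dividing the pre-softmax scores by an extra sqrt(temp); the shapes below are illustrative only.

    import torch
    from torch.nn.functional import scaled_dot_product_attention, softmax

    q = torch.randn(1, 4, 16, 64)  # (batch, head, ctx, head_dim) -- illustrative shapes
    k = torch.randn(1, 4, 16, 64)
    v = torch.randn(1, 4, 16, 64)
    temp = 2.0

    # What calculate_attention does: pre-scale the query by temp**-0.5 ...
    out_a = scaled_dot_product_attention(q * (1.0 / temp) ** 0.5, k, v)

    # ... which matches dividing the attention scores by an extra sqrt(temp).
    scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5 * temp ** 0.5)
    out_b = softmax(scores, dim=-1) @ v

    print(torch.allclose(out_a, out_b, atol=1e-5))  # True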
@@ -215,26 +129,26 @@ class LocalOut(nn.Module):
 
 class attentionb(nn.Module):
     def __init__(self, dims: int, head: int, max_iter: int = 3,
-                 threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1, temp = 1.0
+                 threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1, temp = 1.0):
         super(attentionb, self).__init__()
 
         self.head = head
         self.dims = dims
-        head_dim = dims // head
+        self.head_dim = dims // head
+        self.win = 0
 
         self.que = nn.Linear(dims, dims, bias=False)
         self.kv = nn.Linear(dims, dims * 2, bias=False)
         self.out = nn.Linear(dims, dims, bias=False)
 
         self.lna = nn.LayerNorm(dims)
-        self.lnb = nn.LayerNorm(head_dim)
+        self.lnb = nn.LayerNorm(self.head_dim)
         self.rope = rotary(dims, head)
-        self.use_win = use_win
 
         self.max_iter = max_iter
-        self.threshold = nn.Parameter(torch.tensor(threshold))
+        self.threshold = nn.Parameter(torch.tensor(threshold), requires_grad=True)
         self.temp = nn.Parameter(torch.tensor(temp), requires_grad=True)
-        self.factor = nn.Parameter(torch.tensor(factor))
+        self.factor = nn.Parameter(torch.tensor(factor), requires_grad=True)
         self.local = LocalOut(dims, head)
 
     def update_win(self, win_size=None):
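Side note on the __init__ change above (not in the diff): nn.Parameter already defaults to requires_grad=True, so spelling the flag out on threshold and factor only documents intent; it does not change behaviour.

    import torch
    from torch import nn

    a = nn.Parameter(torch.tensor(0.01))                      # default
    b = nn.Parameter(torch.tensor(0.01), requires_grad=True)  # explicit, as in the diff
    print(a.requires_grad, b.requires_grad)  # True True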
@@ -246,109 +160,109 @@ class attentionb(nn.Module):
             return win_size
         return None
 
-    def _focus(self, x, xa = None, mask = None,
+    def _focus(self, x, xa = None, mask = None, win_size=None):
 
         q = self.que(self.lna(x))
-        k, v = self.kv(self.lna(x if xa
+        k, v = self.kv(self.lna(x if not xa else xa)).chunk(2, dim=-1)
         q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
 
         self.scale = q.shape[-1] ** -0.35
         q = self.rope(q)
         k = self.rope(k)
 
-        if mask
-        if diff < dthresh and iteration > 0:
-            attn_out = iter_out
-            break
-        prev_out = iter_out.clone()
-        curq = curq + iter_out
-        attn_out = iter_out
+        iteration = 0
+        temp = self.temp
+        prev_out = torch.zeros_like(q)
+        attn_out = torch.zeros_like(q)
+        threshold = self.threshold
+        factor = self.factor
+        curq = q #if curq is None else curq
+
+        while iteration < self.max_iter:
+            eff_span = min(curq.shape[1], k.shape[1])
+            if xa is not None:
+                eff_span = min(eff_span, xa.shape[1])
+            if eff_span == 0:
+                break
+
+            qiter = curq[:, :, :eff_span, :]
+            kiter = k[:, :, :eff_span, :]
+            viter = v[:, :, :eff_span, :]
+            q = self.local.q_hd(qiter)
+            k = self.local.k_hd(kiter)
+            v = self.local.v_hd(viter)
+
+            iter_mask = None
+            if mask is not None:
+                if mask.dim() == 4:
+                    iter_mask = mask[:, :, :eff_span, :eff_span]
+                elif mask.dim() == 2:
+                    iter_mask = mask[:eff_span, :eff_span]
+
+            attn_iter = calculate_attention(
+                self.lnb(q), self.lnb(k), v,
+                mask=iter_mask, temp=temp)
+
+            iter_out = torch.zeros_like(curq)
+            iter_out[:, :, :eff_span, :] = attn_iter
+            diff = torch.abs(iter_out - prev_out).mean()
+            dthresh = threshold + factor * diff
+            if diff < dthresh and iteration > 0:
+                attn_out = iter_out
+                break
+
+            prev_out = iter_out.clone()
+            curq = curq + iter_out
+            attn_out = iter_out
+            # if win_size is not None:
+            #     if win_size > self.win:
+            #         temp += 0.005
+            #     else:
+            #         temp -= 0.005
+            #     self.win = win_size
+            iteration += 1
+
+        out = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2)
+        return out
+
+    def _slide_win_local(self, x, mask = None) -> Tensor:
 
         win = self.update_win()
-        win_size = win if win is not None else
-        span_len = win_size + win_size //
+        win_size = win if win is not None else self.head_dim
+        span_len = win_size + win_size // self.head
 
         _, ctx, _ = x.shape
+        out = torch.zeros_like(x)
         windows = (ctx + win_size - 1) // win_size
 
         for i in range(windows):
            qstart = i * win_size
            qend = min(qstart + win_size, ctx)
-            if
+            qlen = qend - qstart
+            if qlen == 0:
                 continue
 
            kstart = max(0, qend - span_len)
-            kend = qend
            qwin = x[:, qstart:qend, :]
-            kwin = x[:, kstart:
+            kwin = x[:, kstart:qend, :]
 
            win_mask = None
            if mask is not None:
                if mask.dim() == 4:
-                    win_mask = mask[:, :, qstart:qend, kstart:
+                    win_mask = mask[:, :, qstart:qend, kstart:qend]
                elif mask.dim() == 2:
-                    win_mask = mask[qstart:qend, kstart:
+                    win_mask = mask[qstart:qend, kstart:qend]
 
-            attn_out, _ = self._focus(x=qwin, xa=kwin, mask=win_mask, use_win=True)
-            output[:, qstart:qend, :] = attn_out
-            return output
-
-        else:
-            output, _ = self._focus(x, xa, mask, use_win=False)
-            return output
+            attn_out = self._focus(x=qwin, xa=kwin, mask=win_mask, win_size=win_size)
+            out[:, qstart:qend, :] = attn_out
+        return out
+
+    def forward(self, x, xa = None, mask = None):
+        x = self._slide_win_local(x, mask=None)
+        xa = self._slide_win_local(xa, mask=None)
+        output = self._focus(x, xa, mask=None)
+        return self.out(output)
+
 class attentiona(nn.Module):
     def __init__(self, dims: int, head: int, dropout_rate: float = 0.1, cross_talk=False):
         super().__init__()
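Not part of the diff: a standalone toy version of the iterative "focus" loop that the new _focus implements — attend, fold the result back into the query, and stop early once successive outputs stop changing. Names, shapes, and the fixed threshold below are illustrative only; the module itself uses learned threshold/factor/temp parameters, rotary embeddings, and the LocalOut projections.

    import torch
    from torch.nn.functional import scaled_dot_product_attention

    def iterative_focus(q, k, v, max_iter=3, threshold=1e-2):
        # q, k, v: (batch, head, ctx, head_dim), same layout as _focus above.
        prev = torch.zeros_like(q)
        out = torch.zeros_like(q)
        cur_q = q
        for i in range(max_iter):
            step = scaled_dot_product_attention(cur_q, k, v)
            # Early exit once successive iterations barely change the output.
            if i > 0 and (step - prev).abs().mean() < threshold:
                out = step
                break
            prev = step
            cur_q = cur_q + step  # fold the attention output back into the query
            out = step
        return out

    q = k = v = torch.randn(1, 4, 16, 64)
    print(iterative_focus(q, k, v).shape)  # torch.Size([1, 4, 16, 64])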
@@ -370,8 +284,8 @@ class attentiona(nn.Module):
 
         q = self.que(self.ln(x))
         k, v = self.kv(self.ln(x if xa is None else xa)).chunk(2, dim=-1)
-        q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
 
+        q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
         scale = q.shape[-1] ** -0.5
 
         q = self.rope(q)
@@ -389,26 +303,6 @@ class attentiona(nn.Module):
         wv = rearrange(wv, 'b h c d -> b c (h d)')
         out = self.out(wv)
 
-        if self.cross_talk:
-            qk, v = self.kv(self.ln(x)).chunk(2, dim=-1)
-            qka, va = self.kv(self.ln(x if xa is None else xa)).chunk(2, dim=-1)
-            qk, qka, v, va = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.head), (qk, qka, v, va))
-            qk = einsum('b h i d, b h j d -> b h i j', qk, qka)
-            if there_is_a(mask):
-                i, j = qk.shape[-2:]
-                mask = torch.ones(i, j, device=device, dtype=torch.bool).triu(j - i + 1)
-                qk = qk.masked_fill(mask, -torch.finfo(qk.dtype).max)
-            x = qk.softmax(dim = -1)
-            xa = qk.softmax(dim = -2)
-            x = self.x(x)
-            xa = self.xa(xa)
-            x = einsum('b h i j, b h j d -> b h i d', x, va)
-            xa = einsum('b h j i, b h j d -> b h i d', xa, v)
-            x, xa = map(lambda t: rearrange(t, 'b h n d -> b n (h d)'), (x, xa))
-            out = self.out(x) # outxa = self.out(xa)
-
-        return out, qk
-
 class attentiond(nn.Module):
     def __init__(self, dims: int, head: int):
         super().__init__()
@@ -464,11 +358,10 @@ class residual(nn.Module):
 
         self.lna = nn.LayerNorm(dims, bias=False)
         self.atta = attentiona(dims, head)
-        self.attb =
+        self.attb = attentionb(dims, head, max_iter=1)
         # self.attc = attentiona(dims, head, cross_talk=True)
-        self.attb = attentionb(dims, head, max_iter=1, use_win=True, temp=1.0)
-        self.tgate = tgate(dims, num_types=4)
 
+        self.tgate = tgate(dims, num_types=4)
         self.mlp = nn.Sequential(nn.Linear(dims, dims*4), get_activation(act), nn.Linear(dims*4, dims))
 
     def forward(
@@ -479,9 +372,7 @@ class residual(nn.Module):
     ):
         x = x + self.atta(x, mask=mask)[0]
         if xa is not None:
-            x = x + self.attb(x, xa, mask=None)
-            # x = x + self.attc(x, xa, mask=None)[0]
-            # x = x + self.attd(x, xa, mask=None)[0]
+            x = x + self.attb(x, xa, mask=None)
         x = x + self.tgate(x)
         x = x + self.mlp(self.lna(x))
         return x
@@ -507,7 +398,7 @@ class processor(nn.Module):
         mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)
 
-    def forward(self, x, xa, sequential=False, modal=
+    def forward(self, x, xa, sequential=False, modal=False, kv_cache=None) -> Tensor:
 
         if xa.dim() == 2:
             xa = xa.unsqueeze(0)
@@ -554,7 +445,7 @@ class Model(nn.Module):
             module.update_win(win_size)
 
     def adjust_window(self, loss, ctx):
-        self.win_size = ((ctx // self.param.head))
+        self.win_size = ((ctx // self.param.head))
        if loss < self.best_loss:
            win_size = (self.win_size * self.factor) #.clamp(0, ctx - 1)
        else: