Sin2pi committed · verified
Commit bfada3c · 1 Parent(s): 8fdc013

Update model_simple.py

Files changed (1)
  1. model_simple.py +89 -198
model_simple.py CHANGED
@@ -1,15 +1,12 @@
- import os
  import warnings
  import logging
  from itertools import chain
  import torch
- import torch.nn.functional as F
  from torch import nn, Tensor, einsum
- from typing import Optional, List, Tuple, Union
-
  import numpy as np
  from dataclasses import dataclass
-
  from torch.nn.functional import scaled_dot_product_attention
  from einops import rearrange
@@ -27,71 +24,6 @@ def sinusoids(ctx, dims, max_tscale=10000):
      positional_embedding = nn.Parameter(position, requires_grad=True)
      return positional_embedding

- def valid(default_value, *items):
-     for item in items:
-         if item is not None:
-             return item
-     return default_value
-
- def dict_to(d, device, dtype=dtype):
-     return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
-             for k, v in d.items()}
-
-
- class Conv1d(nn.Conv1d):
-     def _conv_forward(
-         self, x: Tensor, weight: Tensor, bias) -> Tensor:
-         return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))
-
- class Conv2d(nn.Conv2d):
-     def _conv_forward(
-         self, x: Tensor, weight: Tensor, bias) -> Tensor:
-         return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))
-
- class Linear(nn.Module):
-     def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
-         super(Linear, self).__init__()
-         self.linear = nn.Linear(in_features, out_features, bias=bias)
-         torch.nn.init.xavier_uniform_(self.linear.weight)
-         if bias:
-             torch.nn.init.zeros_(self.linear.bias)
-     def forward(self, x: Tensor) -> Tensor:
-         return self.linear(x)
-
- class RMSNorm(nn.Module):
-     def __init__(self, dims: Union[int, Tensor, List, Tuple],
-             eps = 1e-8, elementwise_affine = True):
-         super(RMSNorm, self).__init__()
-         if isinstance(dims, int):
-             self.normalized_shape = (dims,)
-         else:
-             self.normalized_shape = tuple(dims)
-         self.eps = eps
-         self.elementwise_affine = elementwise_affine
-         if self.elementwise_affine:
-             self.weight = nn.Parameter(torch.empty(self.normalized_shape)) # type: ignore
-             torch.nn.init.ones_(self.weight)
-         else:
-             self.register_parameter("weight", None)
-     def forward(self, x):
-         return F.rms_norm(x, self.normalized_shape, self.weight, self.eps) # type: ignore
-
- def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
-         weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
-         eps: float = 1e-5) -> Tensor:
-     return F.layer_norm(x, normalized_shape, weight, bias, eps) # type: ignore
-
- def get_device():
-     return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
- def get_dtype():
-     return torch.float32 if torch.cuda.is_available() else torch.float64
-
- def l2norm(tensor):
-     dtype = tensor.dtype
-     normed = F.normalize(tensor, dim = -1)
-     return normed.type(dtype)
-
  def get_activation(act: str) -> nn.Module:
      act_map = {
          "gelu": nn.GELU(),
@@ -110,16 +42,6 @@ def get_activation(act: str) -> nn.Module:
  def there_is_a(val):
      return val is not None

- def to(t):
-     return {'device': t.device, 'dtype': t.dtype}
-
- PATH = 'E:/hf'
- os.environ['HF_HOME'] = PATH
- os.environ['HF_DATASETS_CACHE'] = PATH
- os.environ['TORCH_HOME'] = PATH
- os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-
  @dataclass
  class Dimensions:
      vocab: int
@@ -187,18 +109,10 @@ def calculate_attention(q, k, v, mask=None, temp=1.0):
      scaled_q = q
      if temp != 1.0 and temp > 0:
          scaled_q = q * (1.0 / temp)**.5
      out = scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1)
      return out

- # def calculate_attention(q_norm, k_norm, v_iter, mask=None, temp=1.0):
- #     d_k = q_norm.size(-1)
- #     scores = torch.matmul(q_norm, k_norm.transpose(-2, -1)) / (torch.sqrt(torch.tensor(d_k, dtype=torch.float32)) / temp)
- #     if mask is not None:
- #         scores = scores.masked_fill(mask == 0, float('-inf'))
- #     attention_weights = F.softmax(scores, dim=-1)
- #     output = torch.matmul(attention_weights, v_iter)
- #     return output
-
  class LocalOut(nn.Module):
      def __init__(self, dims: int, head: int):
          super().__init__()
@@ -215,26 +129,26 @@ class LocalOut(nn.Module):

  class attentionb(nn.Module):
      def __init__(self, dims: int, head: int, max_iter: int = 3,
-             threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1, temp = 1.0, use_win=False):
          super(attentionb, self).__init__()

          self.head = head
          self.dims = dims
-         head_dim = dims // head

          self.que = nn.Linear(dims, dims, bias=False)
          self.kv = nn.Linear(dims, dims * 2, bias=False)
          self.out = nn.Linear(dims, dims, bias=False)

          self.lna = nn.LayerNorm(dims)
-         self.lnb = nn.LayerNorm(head_dim)
          self.rope = rotary(dims, head)
-         self.use_win = use_win

          self.max_iter = max_iter
-         self.threshold = nn.Parameter(torch.tensor(threshold))
          self.temp = nn.Parameter(torch.tensor(temp), requires_grad=True)
-         self.factor = nn.Parameter(torch.tensor(factor))
          self.local = LocalOut(dims, head)

      def update_win(self, win_size=None):
@@ -246,109 +160,109 @@ class attentionb(nn.Module):
              return win_size
          return None

-     def _focus(self, x, xa = None, mask = None, use_win = False):

          q = self.que(self.lna(x))
-         k, v = self.kv(self.lna(x if xa is None else xa)).chunk(2, dim=-1)
          q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
-         _, _, ctx, _ = q.shape
          self.scale = q.shape[-1] ** -0.35
          q = self.rope(q)
          k = self.rope(k)

-         if use_win:
-             iteration = 0
-             temp = self.temp.item()
-             prev_out = torch.zeros_like(q)
-             attn_out = torch.zeros_like(q)
-             threshold = self.threshold.item()
-             factor = self.factor.item()
-             curq = q
-
-             while iteration < self.max_iter:
-                 eff_span = min(curq.shape[1], k.shape[1])
-                 if xa is not None:
-                     eff_span = min(eff_span, xa.shape[1])
-                 if eff_span == 0:
-                     break
-
-                 qiter = curq[:, :, :eff_span, :]
-                 kiter = k[:, :, :eff_span, :]
-                 viter = v[:, :, :eff_span, :]
-                 q = self.local.q_hd(qiter)
-                 k = self.local.k_hd(kiter)
-                 v = self.local.v_hd(viter)
-
-                 iter_mask = None
-                 if mask is not None:
-                     if mask.dim() == 4:
-                         iter_mask = mask[:, :, :eff_span, :eff_span]
-                     elif mask.dim() == 2:
-                         iter_mask = mask[:eff_span, :eff_span]
-
-                 attn_iter = calculate_attention(
-                     self.lnb(q), self.lnb(k), v,
-                     mask=iter_mask, temp=temp)
-
-                 iter_out = torch.zeros_like(curq)
-                 iter_out[:, :, :eff_span, :] = attn_iter
-                 diff = torch.abs(iter_out - prev_out).mean()
-                 dthresh = threshold + factor * diff
-                 if diff < dthresh and iteration > 0:
-                     attn_out = iter_out
-                     break
-
-                 prev_out = iter_out.clone()
-                 curq = curq + iter_out
                  attn_out = iter_out
-                 iteration += 1
-                 temp += 0.005
-         else:
-             attn_out = scaled_dot_product_attention(self.lnc(q), self.lnd(k), v, is_causal=mask is not None and ctx >1)
-         wv = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2)
-         qk=None
-         return self.out(wv), qk

-     def _slide_win_local(self, x, mask = None, use_win = True) -> Tensor:

          win = self.update_win()
-         win_size = win if win is not None else 128
-         span_len = win_size + win_size // 10

          _, ctx, _ = x.shape
-         output = torch.zeros_like(x)
          windows = (ctx + win_size - 1) // win_size

          for i in range(windows):
              qstart = i * win_size
              qend = min(qstart + win_size, ctx)
-             win_qlen = qend - qstart
-             if win_qlen == 0:
                  continue

              kstart = max(0, qend - span_len)
-             kend = qend
              qwin = x[:, qstart:qend, :]
-             kwin = x[:, kstart:kend, :]

              win_mask = None
              if mask is not None:
                  if mask.dim() == 4:
-                     win_mask = mask[:, :, qstart:qend, kstart:kend]
                  elif mask.dim() == 2:
-                     win_mask = mask[qstart:qend, kstart:kend]
-
-             attn_out, _ = self._focus(x=qwin, xa=kwin, mask=win_mask, use_win=True)
-             output[:, qstart:qend, :] = attn_out
-         return output

-     def forward(self, x, xa = None, mask = None, use_win: bool = False):
-         if use_win:
-             return self._slide_win_local(x, mask, use_win=True)
-         else:
-             output, _ = self._focus(x, xa, mask, use_win=False)
-             return output

  class attentiona(nn.Module):
      def __init__(self, dims: int, head: int, dropout_rate: float = 0.1, cross_talk=False):
          super().__init__()
@@ -370,8 +284,8 @@ class attentiona(nn.Module):

          q = self.que(self.ln(x))
          k, v = self.kv(self.ln(x if xa is None else xa)).chunk(2, dim=-1)
-         q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))

          scale = q.shape[-1] ** -0.5

          q = self.rope(q)
@@ -389,26 +303,6 @@ class attentiona(nn.Module):
          wv = rearrange(wv, 'b h c d -> b c (h d)')
          out = self.out(wv)

-         if self.cross_talk:
-             qk, v = self.kv(self.ln(x)).chunk(2, dim=-1)
-             qka, va = self.kv(self.ln(x if xa is None else xa)).chunk(2, dim=-1)
-             qk, qka, v, va = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.head), (qk, qka, v, va))
-             qk = einsum('b h i d, b h j d -> b h i j', qk, qka)
-             if there_is_a(mask):
-                 i, j = qk.shape[-2:]
-                 mask = torch.ones(i, j, device=device, dtype=torch.bool).triu(j - i + 1)
-                 qk = qk.masked_fill(mask, -torch.finfo(qk.dtype).max)
-             x = qk.softmax(dim = -1)
-             xa = qk.softmax(dim = -2)
-             x = self.x(x)
-             xa = self.xa(xa)
-             x = einsum('b h i j, b h j d -> b h i d', x, va)
-             xa = einsum('b h j i, b h j d -> b h i d', xa, v)
-             x, xa = map(lambda t: rearrange(t, 'b h n d -> b n (h d)'), (x, xa))
-             out = self.out(x) # outxa = self.out(xa)
-
-         return out, qk
-
  class attentiond(nn.Module):
      def __init__(self, dims: int, head: int):
          super().__init__()
@@ -464,11 +358,10 @@ class residual(nn.Module):

          self.lna = nn.LayerNorm(dims, bias=False)
          self.atta = attentiona(dims, head)
-         self.attb = attentiona(dims, head)
          # self.attc = attentiona(dims, head, cross_talk=True)
-         self.attb = attentionb(dims, head, max_iter=1, use_win=True, temp=1.0)
-         self.tgate = tgate(dims, num_types=4)

          self.mlp = nn.Sequential(nn.Linear(dims, dims*4), get_activation(act), nn.Linear(dims*4, dims))

      def forward(
@@ -479,9 +372,7 @@ class residual(nn.Module):
      ):
          x = x + self.atta(x, mask=mask)[0]
          if xa is not None:
-             x = x + self.attb(x, xa, mask=None)[0]
-             # x = x + self.attc(x, xa, mask=None)[0]
-             # x = x + self.attd(x, xa, mask=None)[0]
          x = x + self.tgate(x)
          x = x + self.mlp(self.lna(x))
          return x
@@ -507,7 +398,7 @@ class processor(nn.Module):
          mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
          self.register_buffer("mask", mask, persistent=False)

-     def forward(self, x, xa, sequential=False, modal=True, kv_cache=None) -> Tensor:

          if xa.dim() == 2:
              xa = xa.unsqueeze(0)
@@ -554,7 +445,7 @@ class Model(nn.Module):
              module.update_win(win_size)

      def adjust_window(self, loss, ctx):
-         self.win_size = ((ctx // self.param.head)) #based on loss but not on the graph itself which is the idea
          if loss < self.best_loss:
              win_size = (self.win_size * self.factor) #.clamp(0, ctx - 1)
          else:
 
model_simple.py UPDATED (new version of the changed hunks)

+
  import warnings
  import logging
  from itertools import chain
  import torch
  from torch import nn, Tensor, einsum
+ from typing import Optional
  import numpy as np
  from dataclasses import dataclass
  from torch.nn.functional import scaled_dot_product_attention
  from einops import rearrange
      positional_embedding = nn.Parameter(position, requires_grad=True)
      return positional_embedding

  def get_activation(act: str) -> nn.Module:
      act_map = {
          "gelu": nn.GELU(),

  def there_is_a(val):
      return val is not None

  @dataclass
  class Dimensions:
      vocab: int
 
      scaled_q = q
      if temp != 1.0 and temp > 0:
          scaled_q = q * (1.0 / temp)**.5
+     print(temp)
      out = scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1)
      return out

  class LocalOut(nn.Module):
      def __init__(self, dims: int, head: int):
          super().__init__()

  class attentionb(nn.Module):
      def __init__(self, dims: int, head: int, max_iter: int = 3,
+             threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1, temp = 1.0):
          super(attentionb, self).__init__()

          self.head = head
          self.dims = dims
+         self.head_dim = dims // head
+         self.win = 0

          self.que = nn.Linear(dims, dims, bias=False)
          self.kv = nn.Linear(dims, dims * 2, bias=False)
          self.out = nn.Linear(dims, dims, bias=False)

          self.lna = nn.LayerNorm(dims)
+         self.lnb = nn.LayerNorm(self.head_dim)
          self.rope = rotary(dims, head)

          self.max_iter = max_iter
+         self.threshold = nn.Parameter(torch.tensor(threshold), requires_grad=True)
          self.temp = nn.Parameter(torch.tensor(temp), requires_grad=True)
+         self.factor = nn.Parameter(torch.tensor(factor), requires_grad=True)
          self.local = LocalOut(dims, head)

      def update_win(self, win_size=None):
              return win_size
          return None

+     def _focus(self, x, xa = None, mask = None, win_size=None):

          q = self.que(self.lna(x))
+         k, v = self.kv(self.lna(x if not xa else xa)).chunk(2, dim=-1)
          q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
+
          self.scale = q.shape[-1] ** -0.35
          q = self.rope(q)
          k = self.rope(k)

+         iteration = 0
+         temp = self.temp
+         prev_out = torch.zeros_like(q)
+         attn_out = torch.zeros_like(q)
+         threshold = self.threshold
+         factor = self.factor
+         curq = q #if curq is None else curq
+
+         while iteration < self.max_iter:
+             eff_span = min(curq.shape[1], k.shape[1])
+             if xa is not None:
+                 eff_span = min(eff_span, xa.shape[1])
+             if eff_span == 0:
+                 break
+
+             qiter = curq[:, :, :eff_span, :]
+             kiter = k[:, :, :eff_span, :]
+             viter = v[:, :, :eff_span, :]
+             q = self.local.q_hd(qiter)
+             k = self.local.k_hd(kiter)
+             v = self.local.v_hd(viter)
+
+             iter_mask = None
+             if mask is not None:
+                 if mask.dim() == 4:
+                     iter_mask = mask[:, :, :eff_span, :eff_span]
+                 elif mask.dim() == 2:
+                     iter_mask = mask[:eff_span, :eff_span]
+
+             attn_iter = calculate_attention(
+                 self.lnb(q), self.lnb(k), v,
+                 mask=iter_mask, temp=temp)
+
+             iter_out = torch.zeros_like(curq)
+             iter_out[:, :, :eff_span, :] = attn_iter
+             diff = torch.abs(iter_out - prev_out).mean()
+             dthresh = threshold + factor * diff
+             if diff < dthresh and iteration > 0:
                  attn_out = iter_out
+                 break
+
+             prev_out = iter_out.clone()
+             curq = curq + iter_out
+             attn_out = iter_out
+             # if win_size is not None:
+             #     if win_size > self.win:
+             #         temp += 0.005
+             #     else:
+             #         temp -= 0.005
+             # self.win = win_size
+             iteration += 1

+         out = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2)
+         return out
+
+     def _slide_win_local(self, x, mask = None) -> Tensor:

          win = self.update_win()
+         win_size = win if win is not None else self.head_dim
+         span_len = win_size + win_size // self.head

          _, ctx, _ = x.shape
+         out = torch.zeros_like(x)
          windows = (ctx + win_size - 1) // win_size

          for i in range(windows):
              qstart = i * win_size
              qend = min(qstart + win_size, ctx)
+             qlen = qend - qstart
+             if qlen == 0:
                  continue

              kstart = max(0, qend - span_len)
              qwin = x[:, qstart:qend, :]
+             kwin = x[:, kstart:qend, :]

              win_mask = None
              if mask is not None:
                  if mask.dim() == 4:
+                     win_mask = mask[:, :, qstart:qend, kstart:qend]
                  elif mask.dim() == 2:
+                     win_mask = mask[qstart:qend, kstart:qend]

+             attn_out = self._focus(x=qwin, xa=kwin, mask=win_mask, win_size=win_size)
+             out[:, qstart:qend, :] = attn_out
+         return out

+     def forward(self, x, xa = None, mask = None):
+         x = self._slide_win_local(x, mask=None)
+         xa = self._slide_win_local(xa, mask=None)
+         output = self._focus(x, xa, mask=None)
+         return self.out(output)
+
  class attentiona(nn.Module):
      def __init__(self, dims: int, head: int, dropout_rate: float = 0.1, cross_talk=False):
          super().__init__()

          q = self.que(self.ln(x))
          k, v = self.kv(self.ln(x if xa is None else xa)).chunk(2, dim=-1)

+         q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
          scale = q.shape[-1] ** -0.5

          q = self.rope(q)

          wv = rearrange(wv, 'b h c d -> b c (h d)')
          out = self.out(wv)

  class attentiond(nn.Module):
      def __init__(self, dims: int, head: int):
          super().__init__()
 

          self.lna = nn.LayerNorm(dims, bias=False)
          self.atta = attentiona(dims, head)
+         self.attb = attentionb(dims, head, max_iter=1)
          # self.attc = attentiona(dims, head, cross_talk=True)

+         self.tgate = tgate(dims, num_types=4)
          self.mlp = nn.Sequential(nn.Linear(dims, dims*4), get_activation(act), nn.Linear(dims*4, dims))

      def forward(

      ):
          x = x + self.atta(x, mask=mask)[0]
          if xa is not None:
+             x = x + self.attb(x, xa, mask=None)
          x = x + self.tgate(x)
          x = x + self.mlp(self.lna(x))
          return x
 
          mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
          self.register_buffer("mask", mask, persistent=False)

+     def forward(self, x, xa, sequential=False, modal=False, kv_cache=None) -> Tensor:

          if xa.dim() == 2:
              xa = xa.unsqueeze(0)

              module.update_win(win_size)

      def adjust_window(self, loss, ctx):
+         self.win_size = ((ctx // self.param.head))
          if loss < self.best_loss:
              win_size = (self.win_size * self.factor) #.clamp(0, ctx - 1)
          else:
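
For orientation, here is a minimal, self-contained sketch (not the repo's code) of the two mechanisms this commit wires together in `attentionb`: a sliding-window pass over the sequence (`_slide_win_local`) and, inside each window, an iterative "focus" loop that re-attends until the output stops changing (`_focus`). The window partitioning, the lookback span for keys/values, and the dynamic stopping rule mirror the diff; the function names, window sizes, and head count below are illustrative assumptions, and the real module's projections, RoPE, and LayerNorms are omitted.

```python
import torch
from torch.nn.functional import scaled_dot_product_attention

def focus(q, k, v, max_iter=3, threshold=0.01, factor=0.1):
    # Iterate attention, feeding each output back into the query, and stop once
    # the mean change between iterations falls below a dynamic threshold
    # (dthresh = threshold + factor * diff, as in the diff).
    prev = torch.zeros_like(q)
    cur_q, out = q, torch.zeros_like(q)
    for i in range(max_iter):
        attn = scaled_dot_product_attention(cur_q, k, v)
        diff = (attn - prev).abs().mean()
        out = attn
        if i > 0 and diff < threshold + factor * diff:
            break
        prev = attn.clone()
        cur_q = cur_q + attn
    return out

def slide_win(x, head=4, win_size=32, span_extra=4):
    # Split the sequence into query windows; keys/values come from a slightly
    # longer lookback window ending at the same position, as in _slide_win_local.
    b, ctx, dims = x.shape
    head_dim = dims // head
    split = lambda t: t.reshape(b, -1, head, head_dim).transpose(1, 2)  # (b, h, n, d)
    merge = lambda t: t.transpose(1, 2).reshape(b, -1, dims)            # (b, n, dims)
    out = torch.zeros_like(x)
    for qstart in range(0, ctx, win_size):
        qend = min(qstart + win_size, ctx)
        kstart = max(0, qend - (win_size + span_extra))
        qwin, kwin = x[:, qstart:qend], x[:, kstart:qend]
        out[:, qstart:qend] = merge(focus(split(qwin), split(kwin), split(kwin)))
    return out

print(slide_win(torch.randn(2, 100, 64)).shape)  # torch.Size([2, 100, 64])
```

In the committed `attentionb.forward`, both `x` and `xa` are first run through this windowed pass, then combined by one more `_focus` step across the two streams before the output projection.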