Sin2pi commited on
Commit
87c3f87
·
verified ·
1 Parent(s): c12de10

Update model_simple.py

Browse files
Files changed (1) hide show
  1. model_simple.py +1 -12
model_simple.py CHANGED
@@ -37,11 +37,6 @@ def dict_to(d, device, dtype=dtype):
37
  return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
38
  for k, v in d.items()}
39
 
40
- def exists(v):
41
- return v is not None
42
-
43
- def default(v, b):
44
- return v if exists(v) else b
45
 
46
  class Conv1d(nn.Conv1d):
47
  def _conv_forward(
@@ -115,12 +110,6 @@ def get_activation(act: str) -> nn.Module:
115
  def there_is_a(val):
116
  return val is not None
117
 
118
- def exists(val):
119
- return val is not None
120
-
121
- def default(value, d):
122
- return d if not exists(value) else value
123
-
124
  def to(t):
125
  return {'device': t.device, 'dtype': t.dtype}
126
 
@@ -260,7 +249,7 @@ class attentionb(nn.Module):
260
  def _focus(self, x, xa = None, mask = None, use_win = False):
261
 
262
  q = self.que(self.lna(x))
263
- k, v = self.kv(self.lna(default(xa, x))).chunk(2, dim=-1)
264
  q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
265
  _, _, ctx, _ = q.shape
266
  self.scale = q.shape[-1] ** -0.35
 
37
  return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
38
  for k, v in d.items()}
39
 
 
 
 
 
 
40
 
41
  class Conv1d(nn.Conv1d):
42
  def _conv_forward(
 
110
  def there_is_a(val):
111
  return val is not None
112
 
 
 
 
 
 
 
113
  def to(t):
114
  return {'device': t.device, 'dtype': t.dtype}
115
 
 
249
  def _focus(self, x, xa = None, mask = None, use_win = False):
250
 
251
  q = self.que(self.lna(x))
252
+ k, v = self.kv(self.lna(x if xa is None else xa)).chunk(2, dim=-1)
253
  q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
254
  _, _, ctx, _ = q.shape
255
  self.scale = q.shape[-1] ** -0.35