Update model_simple.py
Browse files- model_simple.py +1 -12
model_simple.py
CHANGED
@@ -37,11 +37,6 @@ def dict_to(d, device, dtype=dtype):
|
|
37 |
return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
|
38 |
for k, v in d.items()}
|
39 |
|
40 |
-
def exists(v):
|
41 |
-
return v is not None
|
42 |
-
|
43 |
-
def default(v, b):
|
44 |
-
return v if exists(v) else b
|
45 |
|
46 |
class Conv1d(nn.Conv1d):
|
47 |
def _conv_forward(
|
@@ -115,12 +110,6 @@ def get_activation(act: str) -> nn.Module:
|
|
115 |
def there_is_a(val):
|
116 |
return val is not None
|
117 |
|
118 |
-
def exists(val):
|
119 |
-
return val is not None
|
120 |
-
|
121 |
-
def default(value, d):
|
122 |
-
return d if not exists(value) else value
|
123 |
-
|
124 |
def to(t):
|
125 |
return {'device': t.device, 'dtype': t.dtype}
|
126 |
|
@@ -260,7 +249,7 @@ class attentionb(nn.Module):
|
|
260 |
def _focus(self, x, xa = None, mask = None, use_win = False):
|
261 |
|
262 |
q = self.que(self.lna(x))
|
263 |
-
k, v = self.kv(self.lna(
|
264 |
q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
|
265 |
_, _, ctx, _ = q.shape
|
266 |
self.scale = q.shape[-1] ** -0.35
|
|
|
37 |
return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
|
38 |
for k, v in d.items()}
|
39 |
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
class Conv1d(nn.Conv1d):
|
42 |
def _conv_forward(
|
|
|
110 |
def there_is_a(val):
|
111 |
return val is not None
|
112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
def to(t):
|
114 |
return {'device': t.device, 'dtype': t.dtype}
|
115 |
|
|
|
249 |
def _focus(self, x, xa = None, mask = None, use_win = False):
|
250 |
|
251 |
q = self.que(self.lna(x))
|
252 |
+
k, v = self.kv(self.lna(x if xa is None else xa)).chunk(2, dim=-1)
|
253 |
q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
|
254 |
_, _, ctx, _ = q.shape
|
255 |
self.scale = q.shape[-1] ** -0.35
|