trying to fix test on CI
Browse files
tests/modules/test_transformer.py
CHANGED
|
@@ -132,8 +132,8 @@ def test_attention_as_float32():
|
|
| 132 |
|
| 133 |
@torch.no_grad()
|
| 134 |
def test_streaming_memory_efficient():
|
| 135 |
-
torch.manual_seed(1234)
|
| 136 |
for backend in ['torch', 'xformers']:
|
|
|
|
| 137 |
set_efficient_attention_backend(backend)
|
| 138 |
tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0., custom=True)
|
| 139 |
tr_mem_efficient = StreamingTransformer(
|
|
|
|
| 132 |
|
| 133 |
@torch.no_grad()
|
| 134 |
def test_streaming_memory_efficient():
|
|
|
|
| 135 |
for backend in ['torch', 'xformers']:
|
| 136 |
+
torch.manual_seed(1234)
|
| 137 |
set_efficient_attention_backend(backend)
|
| 138 |
tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0., custom=True)
|
| 139 |
tr_mem_efficient = StreamingTransformer(
|