fix sdp attention to use the flash/mem-efficient context manager
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
CHANGED
@@ -184,14 +184,15 @@ def sdp_attention_forward(
 
     # We only apply sdp attention if we don't need to output the whole attention matrix
     if not output_attentions:
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-            is_causal=False,
-        )
-        attn_weights = None
+        with torch.backends.cuda.sdp_kernel():
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+                is_causal=False,
+            )
+            attn_weights = None
     else:
         attn_weights = torch.matmul(
             query_states, key_states.transpose(2, 3)
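For context, a minimal standalone sketch (not part of this patch) of how the torch.backends.cuda.sdp_kernel context manager is typically used to force scaled_dot_product_attention onto the flash and memory-efficient backends by disabling the math fallback. The tensor shapes, dtype, and the is_causal flag below are illustrative assumptions, not values taken from the patched function.

# Illustrative sketch, not part of this patch: restrict SDPA to the
# flash / memory-efficient kernels by disabling the math fallback.
# Shapes, dtype, and the causal flag are assumptions for the example.
import torch
import torch.nn.functional as F

query = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
key = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
value = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=True
):
    out = F.scaled_dot_product_attention(query, key, value, is_causal=True)

The flash backend in this PyTorch generation only accepts half-precision inputs and causal or unmasked attention, which is why the sketch passes is_causal=True rather than an explicit attn_mask. All three enable_* arguments of sdp_kernel default to True, so the no-argument form used in the diff above keeps every backend available and lets PyTorch pick among them.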