Lint flash_attn.py
Browse files
src/axolotl/flash_attn.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
|
|
|
|
|
|
| 1 |
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
| 2 |
|
| 3 |
-
from typing import
|
| 4 |
|
| 5 |
import torch
|
| 6 |
-
from torch import nn
|
| 7 |
|
| 8 |
import transformers
|
| 9 |
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
|
@@ -14,7 +15,7 @@ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
|
| 14 |
from flash_attn.bert_padding import unpad_input, pad_input
|
| 15 |
|
| 16 |
|
| 17 |
-
def forward(
|
| 18 |
self,
|
| 19 |
hidden_states: torch.Tensor,
|
| 20 |
attention_mask: Optional[torch.Tensor] = None,
|
|
@@ -82,6 +83,8 @@ def forward(
|
|
| 82 |
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
| 83 |
else:
|
| 84 |
nheads = qkv.shape[-2]
|
|
|
|
|
|
|
| 85 |
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
| 86 |
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
| 87 |
x_unpad = rearrange(
|
|
@@ -104,13 +107,13 @@ def forward(
|
|
| 104 |
# requires the attention mask to be the same as the key_padding_mask
|
| 105 |
def _prepare_decoder_attention_mask(
|
| 106 |
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
| 107 |
-
):
|
| 108 |
# [bsz, seq_len]
|
| 109 |
return attention_mask
|
| 110 |
|
| 111 |
|
| 112 |
def replace_llama_attn_with_flash_attn():
|
| 113 |
-
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
|
| 114 |
_prepare_decoder_attention_mask
|
| 115 |
)
|
| 116 |
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
|
|
|
|
| 1 |
+
"""Flash attention monkey patch for llama model"""
|
| 2 |
+
|
| 3 |
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
| 4 |
|
| 5 |
+
from typing import Optional, Tuple
|
| 6 |
|
| 7 |
import torch
|
|
|
|
| 8 |
|
| 9 |
import transformers
|
| 10 |
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
|
|
|
| 15 |
from flash_attn.bert_padding import unpad_input, pad_input
|
| 16 |
|
| 17 |
|
| 18 |
+
def forward( # pylint: disable=too-many-arguments
|
| 19 |
self,
|
| 20 |
hidden_states: torch.Tensor,
|
| 21 |
attention_mask: Optional[torch.Tensor] = None,
|
|
|
|
| 83 |
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
| 84 |
else:
|
| 85 |
nheads = qkv.shape[-2]
|
| 86 |
+
|
| 87 |
+
# pylint: disable=invalid-name
|
| 88 |
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
| 89 |
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
| 90 |
x_unpad = rearrange(
|
|
|
|
| 107 |
# requires the attention mask to be the same as the key_padding_mask
|
| 108 |
def _prepare_decoder_attention_mask(
|
| 109 |
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
| 110 |
+
): # pylint: disable=unused-argument
|
| 111 |
# [bsz, seq_len]
|
| 112 |
return attention_mask
|
| 113 |
|
| 114 |
|
| 115 |
def replace_llama_attn_with_flash_attn():
|
| 116 |
+
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
| 117 |
_prepare_decoder_attention_mask
|
| 118 |
)
|
| 119 |
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
|