skip some flash attn patches unless explicitly enabled (#643)
Browse files* skip some flash attn patches if explicitly disabled
* make the other patches optional
- README.md +2 -0
- src/axolotl/monkeypatch/llama_attn_hijack_flash.py +31 -23
- src/axolotl/utils/models.py +5 -1
README.md
CHANGED
|
@@ -636,6 +636,8 @@ flash_optimum:
|
|
| 636 |
xformers_attention:
|
| 637 |
# whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
| 638 |
flash_attention:
|
|
|
|
|
|
|
| 639 |
# whether to use scaled-dot-product attention
|
| 640 |
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
| 641 |
sdp_attention:
|
|
|
|
| 636 |
xformers_attention:
|
| 637 |
# whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
| 638 |
flash_attention:
|
| 639 |
+
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
| 640 |
+
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
| 641 |
# whether to use scaled-dot-product attention
|
| 642 |
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
| 643 |
sdp_attention:
|
src/axolotl/monkeypatch/llama_attn_hijack_flash.py
CHANGED
|
@@ -38,7 +38,11 @@ except ImportError:
|
|
| 38 |
LOG = logging.getLogger("axolotl")
|
| 39 |
|
| 40 |
|
| 41 |
-
def replace_llama_attn_with_flash_attn(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
| 43 |
_prepare_decoder_attention_mask
|
| 44 |
)
|
|
@@ -49,33 +53,37 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
|
|
| 49 |
llama_model_forward
|
| 50 |
)
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
| 66 |
|
| 67 |
-
|
| 68 |
-
|
| 69 |
|
| 70 |
-
|
| 71 |
-
|
| 72 |
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
|
| 80 |
|
| 81 |
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
|
|
|
| 38 |
LOG = logging.getLogger("axolotl")
|
| 39 |
|
| 40 |
|
| 41 |
+
def replace_llama_attn_with_flash_attn(
|
| 42 |
+
packed: Optional[bool] = False,
|
| 43 |
+
cross_entropy: Optional[bool] = False,
|
| 44 |
+
rms_norm: Optional[bool] = False,
|
| 45 |
+
):
|
| 46 |
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
| 47 |
_prepare_decoder_attention_mask
|
| 48 |
)
|
|
|
|
| 53 |
llama_model_forward
|
| 54 |
)
|
| 55 |
|
| 56 |
+
# skip only if explicitly disabled
|
| 57 |
+
if cross_entropy:
|
| 58 |
+
try:
|
| 59 |
+
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
| 60 |
|
| 61 |
+
LOG.info("patching with flash_attn.losses.cross_entropy")
|
| 62 |
+
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
| 63 |
+
CrossEntropyLoss, inplace_backward=True
|
| 64 |
+
)
|
| 65 |
+
except ImportError:
|
| 66 |
+
LOG.info(
|
| 67 |
+
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
| 68 |
+
)
|
| 69 |
|
| 70 |
+
# skip only if explicitly disabled
|
| 71 |
+
if rms_norm:
|
| 72 |
+
try:
|
| 73 |
+
from flash_attn.ops.rms_norm import RMSNorm
|
| 74 |
|
| 75 |
+
class LlamaRMSNorm(RMSNorm):
|
| 76 |
+
"""Patched LLamaRMSNorm"""
|
| 77 |
|
| 78 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 79 |
+
super().__init__(hidden_size, eps=eps)
|
| 80 |
|
| 81 |
+
LOG.info("patching with flash_attn.ops.rms_norm")
|
| 82 |
+
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
| 83 |
+
except ImportError:
|
| 84 |
+
LOG.info(
|
| 85 |
+
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
| 86 |
+
)
|
| 87 |
|
| 88 |
|
| 89 |
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
src/axolotl/utils/models.py
CHANGED
|
@@ -121,7 +121,11 @@ def load_model(
|
|
| 121 |
)
|
| 122 |
|
| 123 |
LOG.info("patching with flash attention for sample packing")
|
| 124 |
-
replace_llama_attn_with_flash_attn(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
elif cfg.is_llama_derived_model and cfg.xformers_attention:
|
| 126 |
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
| 127 |
hijack_llama_attention,
|
|
|
|
| 121 |
)
|
| 122 |
|
| 123 |
LOG.info("patching with flash attention for sample packing")
|
| 124 |
+
replace_llama_attn_with_flash_attn(
|
| 125 |
+
packed=cfg.sample_packing,
|
| 126 |
+
cross_entropy=cfg.flash_attn_cross_entropy,
|
| 127 |
+
rms_norm=cfg.flash_attn_rms_norm,
|
| 128 |
+
)
|
| 129 |
elif cfg.is_llama_derived_model and cfg.xformers_attention:
|
| 130 |
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
| 131 |
hijack_llama_attention,
|