flash attention 2
- docker/Dockerfile-base +1 -1
- src/axolotl/flash_attn.py +3 -3
docker/Dockerfile-base
CHANGED
@@ -40,7 +40,7 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"

 RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
     cd flash-attention && \
-    git checkout && \
+    git checkout v2.0.0 && \
     python3 setup.py bdist_wheel && \
     cd csrc/fused_dense_lib && \
     python3 setup.py bdist_wheel && \
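The build now pins the flash-attention checkout to the v2.0.0 tag before building the wheels. A minimal sanity-check sketch (not part of this commit; assumes the wheels built above were installed into the image, and relies on the package's own `__version__` string):

import flash_attn
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func

# Expect a 2.x build; this import fails on flash-attn 1.x, where the
# function was named flash_attn_unpadded_qkvpacked_func.
assert flash_attn.__version__.startswith("2."), flash_attn.__version__
print("flash-attn", flash_attn.__version__, "with varlen API")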
src/axolotl/flash_attn.py
CHANGED
@@ -8,7 +8,7 @@ import torch
 import transformers
 from einops import rearrange
 from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb


@@ -79,7 +79,7 @@ def forward(
             dtype=torch.int32,
             device=qkv.device,
         )
-        output = flash_attn_unpadded_qkvpacked_func(
+        output = flash_attn_varlen_qkvpacked_func(
             qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
@@ -95,7 +95,7 @@ def forward(
             three=3,
             h=nheads,
         )
-        output_unpad = flash_attn_unpadded_qkvpacked_func(
+        output_unpad = flash_attn_varlen_qkvpacked_func(
             x_unpad,
             cu_q_lens,
             max_s,
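flash-attn 2.x renamed the unpadded helpers to varlen (here, flash_attn_unpadded_qkvpacked_func becomes flash_attn_varlen_qkvpacked_func); as the unchanged context lines show, the arguments (packed QKV, cumulative sequence lengths, max sequence length, dropout probability, softmax scale, causal flag) stay the same. A minimal, self-contained sketch of the renamed call on synthetic data, assuming a CUDA device and flash-attn >= 2.0.0 (shapes and argument order mirror the no-padding branch above):

import torch
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func

bsz, q_len, nheads, headdim = 2, 128, 8, 64
# Packed QKV in the (total_tokens, 3, nheads, headdim) layout the kernel
# expects, in fp16 on the GPU.
qkv = torch.randn(bsz * q_len, 3, nheads, headdim, dtype=torch.float16, device="cuda")
# Cumulative sequence lengths marking batch boundaries: [0, 128, 256].
cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device="cuda")
output = flash_attn_varlen_qkvpacked_func(
    qkv, cu_q_lens, q_len, 0.0, softmax_scale=None, causal=True
)
print(output.shape)  # torch.Size([256, 8, 64])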