drbh committed
Commit · c743a32 · Parent(s): 4762963

fix: align kernel source with latest reference source
Browse files
- .gitignore +8 -1
- build.toml +6 -20
- flash_attn/flash_api.cpp +5 -5
- flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu +0 -14
- flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu +0 -14
- flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu +0 -14
- flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu +0 -14
- flash_attn/src/flash_bwd_launch_template.h +1 -21
- flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu +0 -14
- flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu +0 -14
- flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu +0 -14
- flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu +0 -14
- flash_attn/src/flash_fwd_launch_template.h +2 -31
- flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu +0 -11
- flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu +0 -11
- flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu +0 -11
- flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu +0 -11
- flash_attn/src/generate_kernels.py +1 -1
- flash_attn/src/static_switch.h +0 -3
.gitignore
CHANGED
@@ -1 +1,8 @@
-.bak
+.bak
+__pycache__
+build-ext
+cmake
+result
+CMakeLists.txt
+setup.py
+pyproject.toml
build.toml
CHANGED
@@ -1,10 +1,12 @@
 [general]
 name = "flash_attn"
+universal=false

 [torch]
 src = ["torch-ext/torch_binding.cpp", "torch-ext/torch_binding.h"]

 [kernel.flash_attn]
+backend = "cuda"
 cuda-capabilities = [
   "8.0",
   "9.0",
@@ -13,6 +15,7 @@ cuda-capabilities = [
 ]
 src = [
   "flash_attn/flash_api.cpp",
+
   "flash_attn/src/philox_unpack.cuh",
   "flash_attn/src/namespace_config.h",
   "flash_attn/src/hardware_info.h",
@@ -21,29 +24,18 @@ src = [
   "flash_attn/src/alibi.h",
   "flash_attn/src/block_info.h",
   "flash_attn/src/dropout.h",
-  "flash_attn/src/flash.h",
-  "flash_attn/src/generate_kernels.py",
-  "flash_attn/src/hardware_info.h",
   "flash_attn/src/kernel_traits.h",
   "flash_attn/src/mask.h",
-  "flash_attn/src/namespace_config.h",
   "flash_attn/src/philox.cuh",
-  "flash_attn/src/philox_unpack.cuh",
   "flash_attn/src/rotary.h",
   "flash_attn/src/softmax.h",
-  "flash_attn/src/static_switch.h",
   "flash_attn/src/utils.h",

-
-
+  # bwd kernels
   "flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
   "flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
   "flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
   "flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
-  "flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu",
-  "flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
-  "flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu",
-  "flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
   "flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
   "flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
   "flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
@@ -73,10 +65,6 @@ src = [
   "flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
   "flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
   "flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
-  "flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
-  "flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
-  "flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
-  "flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
   "flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
   "flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
   "flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
@@ -99,14 +87,12 @@ src = [
   "flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
   "flash_attn/src/flash_fwd_kernel.h",
   "flash_attn/src/flash_fwd_launch_template.h",
+
+  # split kernels
   "flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
   "flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
   "flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
   "flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
-  "flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
-  "flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
-  "flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
-  "flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
   "flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
   "flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
   "flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
flash_attn/flash_api.cpp
CHANGED
@@ -432,7 +432,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult
     }

     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_size_rounded = head_size <=
+    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

@@ -644,7 +644,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     }

     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_size_rounded = head_size <=
+    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
     const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);

@@ -831,7 +831,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_size_rounded = head_size <=
+    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

@@ -1048,7 +1048,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }

     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_size_rounded = head_size <=
+    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
     const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);

@@ -1321,7 +1321,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he

     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
     const int head_size = round_multiple(head_size_og, 8);
-    const int head_size_rounded = head_size <=
+    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

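Note: each of the hunks above replaces the per-call head-size rounding with round_multiple(head_size, head_size <= 128 ? 32 : 64). The following standalone sketch (illustration only, not part of the repository) shows what that rule does: head sizes up to 128 are padded to a multiple of 32, larger ones to a multiple of 64, so a head size of 160 is now rounded up to 192 and served by the hdim192 kernels that remain after this commit.

// Standalone sketch of the rounding rule introduced above (illustration only).
#include <cassert>

int main() {
    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    auto head_size_rounded = [&](int head_size) {
        return round_multiple(head_size, head_size <= 128 ? 32 : 64);
    };
    assert(head_size_rounded(96)  == 96);   // <= 128: padded to a multiple of 32
    assert(head_size_rounded(160) == 192);  // > 128: padded to a multiple of 64, lands on hdim192
    assert(head_size_rounded(224) == 256);
    return 0;
}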
flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu
DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_bwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 160, true>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim160<cutlass::bfloat16_t, true>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu
DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_bwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim160<cutlass::bfloat16_t, false>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu
DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_bwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_bwd_<cutlass::half_t, 160, true>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim160<cutlass::half_t, true>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu
DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_bwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_bwd_<cutlass::half_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim160<cutlass::half_t, false>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_launch_template.h
CHANGED
@@ -102,7 +102,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
     // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
     // If Is_local, set Is_causal to false
-    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
+    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !Has_alibi && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !Has_alibi, Is_softcap>;
     // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
     if (smem_size_dq_dk_dv >= 48 * 1024) {
         C10_CUDA_CHECK(cudaFuncSetAttribute(
@@ -261,26 +261,6 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
     });
 }

-template<typename T, bool Is_causal>
-void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
-    constexpr static int Headdim = 160;
-    int device;
-    cudaGetDevice(&device);
-    int max_smem_per_block;
-    cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
-    if (status_ != cudaSuccess) {
-        C10_CUDA_CHECK(status_);
-    }
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        if (max_smem_per_block >= 116 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        } else {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
-        }
-    });
-}
-
 template<typename T, bool Is_causal>
 void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
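Note: besides dropping run_mha_bwd_hdim160, the functional change in this launch template (and, below, in flash_fwd_launch_template.h) is that !Has_alibi is folded into the IsEvenMN/IsEvenK template arguments. Judging from the existing comments above the kernel selection, the intent is to collapse specializations: when ALiBi is enabled the call reuses an already-compiled "uneven" instantiation instead of generating a new one. A contrived, self-contained sketch of that folding pattern follows (not the flash-attn kernels themselves):

#include <cstdio>

// Illustration only: folding a feature flag into another boolean template
// argument so the flag's "true" case reuses an existing instantiation.
template <bool IsEvenMN>
void toy_kernel() { printf("instantiated toy_kernel<IsEvenMN=%d>\n", int(IsEvenMN)); }

template <bool IsEvenMN, bool HasAlibi>
void launch() {
    // Mirrors `IsEvenMNConst && ... && !Has_alibi` in the diff above:
    // HasAlibi == true always selects the IsEvenMN == false specialization.
    toy_kernel<IsEvenMN && !HasAlibi>();
}

int main() {
    launch<true, false>();   // uses toy_kernel<true>
    launch<true, true>();    // collapses onto toy_kernel<false>
    launch<false, true>();   // toy_kernel<false>
    return 0;
}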
flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu
DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim160<cutlass::bfloat16_t, true>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu
DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim160<cutlass::bfloat16_t, false>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu
DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_fwd_<cutlass::half_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim160<cutlass::half_t, true>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu
DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_fwd_<cutlass::half_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim160<cutlass::half_t, false>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_launch_template.h
CHANGED
@@ -76,7 +76,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     // If return_softmax, set IsEvenMNConst to false to reduce number of templates
     // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
     // If Is_local, set Is_causal to false
-    auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
+    auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !Has_alibi && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !ReturnSoftmaxConst && !Has_alibi, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
     // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
     // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
     // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
@@ -117,7 +117,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
     // If Is_local, set Is_causal to false
-    auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>;
+    auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && !Has_alibi && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !Has_alibi, Is_softcap, Split, Append_KV>;
     // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
     // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
     if (smem_size >= 48 * 1024) {
@@ -165,7 +165,6 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream)
     constexpr static int kBlockM = 64; // Fixed for all head dimensions
     // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
     // and for headdim 192 with block size 64 x 128.
-    // Also for headdim 160 with block size 64 x 128 after the rotary addition.
     constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
     run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
 }
@@ -257,34 +256,6 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     });
 }

-template<typename T, bool Is_causal>
-void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr static int Headdim = 160;
-    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
-    bool is_sm8x = cc_major == 8 && cc_minor > 0;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        // For A100, H100, 128 x 32 is the fastest.
-        // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
-        // and 128 x 64 with 8 warps is the fastest for non-causal.
-        if (is_sm8x) {
-            if constexpr(!Is_causal) {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            } else {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            }
-        } else {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        }
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
-    });
-}
-
 template<typename T, bool Is_causal>
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu
DELETED
@@ -1,11 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream);
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu
DELETED
@@ -1,11 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream);
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu
DELETED
@@ -1,11 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream);
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu
DELETED
@@ -1,11 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream);
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/generate_kernels.py
CHANGED
@@ -10,7 +10,7 @@ DTYPE_MAP = {
 }

 SM = [80]  # Sm80 kernels support up to
-HEAD_DIMENSIONS = [32, 64, 96, 128,
+HEAD_DIMENSIONS = [32, 64, 96, 128, 192, 256]
 IS_CAUSAL = ["false", "true"]
 NAMESPACE_INCLUDE = '#include "namespace_config.h"\n'

flash_attn/src/static_switch.h
CHANGED
@@ -101,9 +101,6 @@
   } else if (HEADDIM <= 128) { \
     constexpr static int kHeadDim = 128; \
     return __VA_ARGS__(); \
-  } else if (HEADDIM <= 160) { \
-    constexpr static int kHeadDim = 160; \
-    return __VA_ARGS__(); \
   } else if (HEADDIM <= 192) { \
     constexpr static int kHeadDim = 192; \
     return __VA_ARGS__(); \