kernel
drbh committed on
Commit c743a32 · 1 Parent(s): 4762963

fix: align kernel source with latest reference source

.gitignore CHANGED
@@ -1 +1,8 @@
-.bak
+.bak
+__pycache__
+build-ext
+cmake
+result
+CMakeLists.txt
+setup.py
+pyproject.toml
build.toml CHANGED
@@ -1,10 +1,12 @@
 [general]
 name = "flash_attn"
+universal=false
 
 [torch]
 src = ["torch-ext/torch_binding.cpp", "torch-ext/torch_binding.h"]
 
 [kernel.flash_attn]
+backend = "cuda"
 cuda-capabilities = [
   "8.0",
   "9.0",
@@ -13,6 +15,7 @@ cuda-capabilities = [
 ]
 src = [
   "flash_attn/flash_api.cpp",
+
   "flash_attn/src/philox_unpack.cuh",
   "flash_attn/src/namespace_config.h",
   "flash_attn/src/hardware_info.h",
@@ -21,29 +24,18 @@ src = [
   "flash_attn/src/alibi.h",
   "flash_attn/src/block_info.h",
   "flash_attn/src/dropout.h",
-  "flash_attn/src/flash.h",
-  "flash_attn/src/generate_kernels.py",
-  "flash_attn/src/hardware_info.h",
   "flash_attn/src/kernel_traits.h",
   "flash_attn/src/mask.h",
-  "flash_attn/src/namespace_config.h",
   "flash_attn/src/philox.cuh",
-  "flash_attn/src/philox_unpack.cuh",
   "flash_attn/src/rotary.h",
   "flash_attn/src/softmax.h",
-  "flash_attn/src/static_switch.h",
   "flash_attn/src/utils.h",
 
-  ## bwd kernels
-
+  # bwd kernels
   "flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
   "flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
   "flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
   "flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
-  "flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu",
-  "flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
-  "flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu",
-  "flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
   "flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
   "flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
   "flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
@@ -73,10 +65,6 @@ src = [
   "flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
   "flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
   "flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
-  "flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
-  "flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
-  "flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
-  "flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
   "flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
   "flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
   "flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
@@ -99,14 +87,12 @@ src = [
   "flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
   "flash_attn/src/flash_fwd_kernel.h",
   "flash_attn/src/flash_fwd_launch_template.h",
+
+  # split kernels
   "flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
   "flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
   "flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
   "flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
-  "flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
-  "flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
-  "flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
-  "flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
   "flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
   "flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
   "flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
flash_attn/flash_api.cpp CHANGED
@@ -432,7 +432,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult
     }
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
+    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
 
@@ -644,7 +644,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     }
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
+    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
     const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
 
@@ -831,7 +831,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
+    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
 
@@ -1048,7 +1048,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
+    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
     const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
 
@@ -1321,7 +1321,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
     const int head_size = round_multiple(head_size_og, 8);
-    const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
+    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
 
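Note on the head_size_rounded change above: the old expression rounded the head dim up to a multiple of 32 and clamped anything above 192 to 256, while the new expression rounds up to a multiple of 32 for head dims up to 128 and to a multiple of 64 beyond that. In particular, a head dim of 160 now rounds to 192, which lines up with the removal of the dedicated hdim160 kernels below. A small self-contained C++ check of the two formulas (illustrative only, not part of the commit):

// Compare the pre- and post-commit head_size rounding (formulas copied from the diff).
#include <cstdio>
#include <initializer_list>

int round_multiple(int x, int m) { return (x + m - 1) / m * m; }

int old_rounding(int head_size) {            // before this commit
    return head_size <= 192 ? round_multiple(head_size, 32) : 256;
}

int new_rounding(int head_size) {            // after this commit
    return round_multiple(head_size, head_size <= 128 ? 32 : 64);
}

int main() {
    for (int head_size : {40, 96, 128, 160, 200, 224}) {
        std::printf("head_size=%3d  old=%3d  new=%3d\n",
                    head_size, old_rounding(head_size), new_rounding(head_size));
    }
    // e.g. 160 -> 160 before vs 192 after; 40, 96, 128, 200, 224 are unchanged.
    return 0;
}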
flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_bwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 160, true>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim160<cutlass::bfloat16_t, true>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_bwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim160<cutlass::bfloat16_t, false>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_bwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_bwd_<cutlass::half_t, 160, true>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim160<cutlass::half_t, true>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_bwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_bwd_<cutlass::half_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim160<cutlass::half_t, false>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_launch_template.h CHANGED
@@ -102,7 +102,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
         // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
         // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
         // If Is_local, set Is_causal to false
-        auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
+        auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !Has_alibi && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !Has_alibi, Is_softcap>;
         // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
         if (smem_size_dq_dk_dv >= 48 * 1024) {
             C10_CUDA_CHECK(cudaFuncSetAttribute(
@@ -261,26 +261,6 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
     });
 }
 
-template<typename T, bool Is_causal>
-void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
-    constexpr static int Headdim = 160;
-    int device;
-    cudaGetDevice(&device);
-    int max_smem_per_block;
-    cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
-    if (status_ != cudaSuccess) {
-        C10_CUDA_CHECK(status_);
-    }
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        if (max_smem_per_block >= 116 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        } else {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
-        }
-    });
-}
-
 template<typename T, bool Is_causal>
 void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
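Besides deleting run_mha_bwd_hdim160, the only change to this launch template is the extra `!Has_alibi` terms in the kernel template arguments (the same edit appears in flash_fwd_launch_template.h further down). This follows the rationale already stated in the surrounding comments: constraining the even-MN / even-K flags forces the alibi path onto the generic (uneven) specialization, so the launcher references fewer distinct kernel instantiations. A toy, self-contained sketch of that effect, with made-up names rather than the real kernels:

// Count how many distinct template instantiations a BOOL_SWITCH-style
// dispatcher touches with and without the "&& !Has_alibi" constraint.
#include <cstdio>
#include <set>

template <bool HasAlibi, bool IsEvenMN, bool IsEvenK>
void toy_kernel() {}                     // stand-in for the real kernel template

using KernelPtr = void (*)();

template <bool HasAlibi>
KernelPtr pick(bool is_even_mn, bool is_even_k, bool constrain_by_alibi) {
    // With the constraint, enabling alibi forces the "uneven" code path,
    // mirroring the "&& !Has_alibi" terms added in this commit.
    const bool mn = is_even_mn && !(constrain_by_alibi && HasAlibi);
    const bool k  = is_even_k  && !(constrain_by_alibi && HasAlibi);
    if (mn) { return k ? &toy_kernel<HasAlibi, true, true>
                       : &toy_kernel<HasAlibi, true, false>; }
    return k ? &toy_kernel<HasAlibi, false, true>
             : &toy_kernel<HasAlibi, false, false>;
}

int main() {
    std::set<KernelPtr> before, after;
    for (int alibi = 0; alibi < 2; ++alibi)
        for (int mn = 0; mn < 2; ++mn)
            for (int k = 0; k < 2; ++k) {
                auto dispatch = alibi ? &pick<true> : &pick<false>;
                before.insert(dispatch(mn != 0, k != 0, /*constrain_by_alibi=*/false));
                after.insert(dispatch(mn != 0, k != 0, /*constrain_by_alibi=*/true));
            }
    // Unconstrained dispatch touches all 8 instantiations; constrained only 5.
    std::printf("distinct kernels referenced: before=%zu after=%zu\n",
                before.size(), after.size());
    return 0;
}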
flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim160<cutlass::bfloat16_t, true>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim160<cutlass::bfloat16_t, false>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_fwd_<cutlass::half_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim160<cutlass::half_t, true>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_fwd_<cutlass::half_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim160<cutlass::half_t, false>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_launch_template.h CHANGED
@@ -76,7 +76,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         // If return_softmax, set IsEvenMNConst to false to reduce number of templates
         // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
         // If Is_local, set Is_causal to false
-        auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
+        auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !Has_alibi && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !ReturnSoftmaxConst && !Has_alibi, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
         // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
         // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
         // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
@@ -117,7 +117,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
         // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
         // If Is_local, set Is_causal to false
-        auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>;
+        auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && !Has_alibi && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !Has_alibi, Is_softcap, Split, Append_KV>;
         // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
         // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
         if (smem_size >= 48 * 1024) {
@@ -165,7 +165,6 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream)
     constexpr static int kBlockM = 64;  // Fixed for all head dimensions
     // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
     // and for headdim 192 with block size 64 x 128.
-    // Also for headdim 160 with block size 64 x 128 after the rotary addition.
     constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
     run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
 }
@@ -257,34 +256,6 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     });
 }
 
-template<typename T, bool Is_causal>
-void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr static int Headdim = 160;
-    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
-    bool is_sm8x = cc_major == 8 && cc_minor > 0;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        // For A100, H100, 128 x 32 is the fastest.
-        // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
-        // and 128 x 64 with 8 warps is the fastest for non-causal.
-        if (is_sm8x) {
-            if constexpr(!Is_causal) {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            } else {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            }
-        } else {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        }
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
-    });
-}
-
 template<typename T, bool Is_causal>
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu DELETED
@@ -1,11 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream);
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu DELETED
@@ -1,11 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream);
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu DELETED
@@ -1,11 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream);
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu DELETED
@@ -1,11 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream);
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/generate_kernels.py CHANGED
@@ -10,7 +10,7 @@ DTYPE_MAP = {
 }
 
 SM = [80] # Sm80 kernels support up to
-HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 256]
+HEAD_DIMENSIONS = [32, 64, 96, 128, 192, 256]
 IS_CAUSAL = ["false", "true"]
 NAMESPACE_INCLUDE = '#include "namespace_config.h"\n'
 
flash_attn/src/static_switch.h CHANGED
@@ -101,9 +101,6 @@
   } else if (HEADDIM <= 128) {            \
     constexpr static int kHeadDim = 128;  \
     return __VA_ARGS__();                 \
-  } else if (HEADDIM <= 160) {            \
-    constexpr static int kHeadDim = 160;  \
-    return __VA_ARGS__();                 \
   } else if (HEADDIM <= 192) {            \
     constexpr static int kHeadDim = 192;  \
     return __VA_ARGS__();                 \
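With the 160 case gone from HEADDIM_SWITCH, any head dim in (128, 192] now falls through to the 192 branch, matching the updated HEAD_DIMENSIONS list in generate_kernels.py and the new head_size rounding in flash_api.cpp. A minimal standalone C++ sketch of the resulting dispatch, assuming the branches outside this hunk cover 32/64/96 and 256 as HEAD_DIMENSIONS suggests:

// Function-based sketch of the post-change head-dim dispatch: the callback
// receives the selected head dim as a compile-time constant.
#include <cstdio>
#include <type_traits>

template <typename F>
void headdim_switch(int headdim, F&& f) {
    if (headdim <= 32)       { f(std::integral_constant<int, 32>{});  }
    else if (headdim <= 64)  { f(std::integral_constant<int, 64>{});  }
    else if (headdim <= 96)  { f(std::integral_constant<int, 96>{});  }
    else if (headdim <= 128) { f(std::integral_constant<int, 128>{}); }
    else if (headdim <= 192) { f(std::integral_constant<int, 192>{}); }  // 160 lands here now
    else                     { f(std::integral_constant<int, 256>{}); }
}

int main() {
    headdim_switch(160, [](auto kHeadDim) {
        std::printf("head dim 160 dispatches to the %d kernels\n", int(kHeadDim));
    });
    return 0;
}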