diff --git a/README.md b/README.md index bc225aef63a8dd17f4d7d4b2d624921aa531f04c..ba4798a1a824db590d9225a54615b149092d50a8 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,14 @@ --- license: apache-2.0 tags: -- kernel + - kernel --- - + # vllm-flash-attn3 This is an implementation of Flash Attention 3 CUDA kernels with support for attention sinks. The attention sinks implementation was contributed to Flash Attention by the [vLLM team](https://huggingface.co/vllm-project). The [transformers team](https://huggingface.co/transformers-community) packaged the implementation and pre-built it for use with the [kernels library](https://github.com/huggingface/kernels). +Kernel source: https://github.com/huggingface/kernels-community/tree/main/vllm-flash-attn3 ## Quickstart @@ -43,7 +44,7 @@ torch.cuda.manual_seed(42) # Parameters batch_size = 2 seqlen_q = 128 # Query sequence length -seqlen_k = 256 # Key sequence length +seqlen_k = 256 # Key sequence length nheads = 8 # Number of attention heads d = 64 # Head dimension @@ -65,7 +66,6 @@ print(f"\nAttention computation successful!") print(f"Output tensor stats - Mean: {output.mean().item():.4f}, Std: {output.std().item():.4f}") ``` - ## How to Use When loading your model with transformers, provide this repository id as the source of the attention implementation: @@ -91,4 +91,5 @@ This will automatically resolve and download the appropriate code for your archi - [Tri Dao](https://huggingface.co/tridao) and team for Flash Attention and [Flash Attention 3](https://tridao.me/blog/2024/flash3/). - The [vLLM team](https://huggingface.co/vllm-project) for their implementation and their contribution of attention sinks. -- The [transformers team](https://huggingface.co/transformers-community) for packaging, testing, building and making it available for use with the [kernels library](https://github.com/huggingface/kernels). \ No newline at end of file +- The [transformers team](https://huggingface.co/transformers-community) for packaging, testing, building and making it available for use with the [kernels library](https://github.com/huggingface/kernels). + diff --git a/build.toml b/build.toml deleted file mode 100644 index a7ef76c9a9b6fd1185051afb43317bafc2cf113f..0000000000000000000000000000000000000000 --- a/build.toml +++ /dev/null @@ -1,593 +0,0 @@ -[general] -name = "vllm_flash_attn3" -universal = false -cuda-minver = "12.4" -cuda-maxver = "12.4" - -[torch] -src = [ - "torch-ext/pytorch_shim.h", - "torch-ext/torch_binding.cpp", - "torch-ext/torch_binding.h", -] - -[kernel.flash_attn] -backend = "cuda" -cuda-capabilities = ["8.0", "9.0a"] -cuda-flags = [ - "-O3", - "-std=c++17", - "--ftemplate-backtrace-limit=0", # To debug template code - "--use_fast_math", - "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", - "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1", - "-DCUTLASS_ENABLE_GDC_FOR_SM90", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "-DNDEBUG", -] -cxx-flags = ["-DFLASHATTENTION_DISABLE_PYBIND"] -src = [ - "flash-attn/cuda_check.h", - "flash-attn/flash_api.cpp", - "flash-attn/flash_fwd_combine.cu", - "flash-attn/flash_fwd_combine_kernel.h", - "flash-attn/flash_fwd_combine_launch_template.h", - "flash-attn/flash.h", - "flash-attn/flash_prepare_scheduler.cu", - "flash-attn/heuristics.h", - "flash-attn/seqlen.h", - "flash-attn/static_switch.h", - "flash-attn/tile_size.h", - "flash-attn/utils.h", -] -depends = ["torch", "cutlass_3_9"] - -[kernel.flash_attn_sm80] -backend = "cuda" -cuda-capabilities = ["8.0", "9.0a"] -cuda-flags = [ - "-O3", - "-std=c++17", - "--ftemplate-backtrace-limit=0", # To debug template code - "--use_fast_math", - "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", - "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1", - "-DCUTLASS_ENABLE_GDC_FOR_SM90", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "-DNDEBUG", -] -src = [ - "flash-attn/block.h", - "flash-attn/copy_sm90_bulk_reduce.hpp", - "flash-attn/epilogue_bwd.hpp", - "flash-attn/epilogue_fwd.hpp", - "flash-attn/flash.h", - "flash-attn/flash_bwd_kernel_sm80.h", - "flash-attn/flash_bwd_kernel_sm90.h", - "flash-attn/flash_bwd_launch_template.h", - "flash-attn/flash_bwd_postprocess_kernel.h", - "flash-attn/flash_bwd_preprocess_kernel.h", - "flash-attn/flash_fwd_launch_template.h", - "flash-attn/flash_fwd_kernel_sm80.h", - "flash-attn/flash_fwd_kernel_sm90.h", - "flash-attn/heuristics.h", - "flash-attn/mainloop_bwd_sm80.hpp", - "flash-attn/mainloop_fwd_sm80.hpp", - "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp", - "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp", - "flash-attn/mask.h", - "flash-attn/named_barrier.hpp", - "flash-attn/pack_gqa.h", - "flash-attn/paged_kv.h", - "flash-attn/rotary.h", - "flash-attn/sm90_pipeline_no_cluster.hpp", - "flash-attn/softmax.h", - "flash-attn/tile_size.h", - "flash-attn/tile_scheduler.hpp", - - "flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu", - "flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu", - "flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu", - "flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu", - "flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu", - "flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu", - "flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu", - "flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu", - "flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu", - "flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu", - "flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu", - "flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm80.cu", - "flash-attn/instantiations/flash_bwd_hdim64_bf16_sm80.cu", - "flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm80.cu", - "flash-attn/instantiations/flash_bwd_hdim64_fp16_sm80.cu", - "flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm80.cu", - "flash-attn/instantiations/flash_bwd_hdim96_bf16_sm80.cu", - "flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm80.cu", - "flash-attn/instantiations/flash_bwd_hdim96_fp16_sm80.cu", - "flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim128_fp16_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim192_bf16_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim192_fp16_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim256_bf16_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim256_fp16_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim64_bf16_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim64_fp16_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim96_bf16_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim96_fp16_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu", - "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu" -] -include = ["flash-attn"] -depends = ["torch", "cutlass_3_9"] - -[kernel.flash_attn_sm90] -backend = "cuda" -cuda-capabilities = ["8.0", "9.0a"] -cuda-flags = [ - "-O3", - "-std=c++17", - "--ftemplate-backtrace-limit=0", # To debug template code - "--use_fast_math", - "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", - "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1", - "-DCUTLASS_ENABLE_GDC_FOR_SM90", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "-DNDEBUG", -] -src = [ - "flash-attn/block.h", - "flash-attn/copy_sm90_bulk_reduce.hpp", - "flash-attn/epilogue_bwd.hpp", - "flash-attn/epilogue_fwd.hpp", - "flash-attn/flash.h", - "flash-attn/flash_bwd_kernel_sm80.h", - "flash-attn/flash_bwd_kernel_sm90.h", - "flash-attn/flash_bwd_launch_template.h", - "flash-attn/flash_bwd_postprocess_kernel.h", - "flash-attn/flash_bwd_preprocess_kernel.h", - "flash-attn/flash_fwd_launch_template.h", - "flash-attn/flash_fwd_kernel_sm80.h", - "flash-attn/flash_fwd_kernel_sm90.h", - "flash-attn/heuristics.h", - "flash-attn/mainloop_bwd_sm80.hpp", - "flash-attn/mainloop_fwd_sm80.hpp", - "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp", - "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp", - "flash-attn/mask.h", - "flash-attn/named_barrier.hpp", - "flash-attn/pack_gqa.h", - "flash-attn/paged_kv.h", - "flash-attn/rotary.h", - "flash-attn/sm90_pipeline_no_cluster.hpp", - "flash-attn/softmax.h", - "flash-attn/tile_size.h", - "flash-attn/tile_scheduler.hpp", - - "flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu", - "flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu", - "flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu", - "flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu", - "flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu", - "flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu", - "flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu", - "flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu", - "flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu", - "flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu", - "flash-attn/instantiations/flash_bwd_hdim256_fp16_sm90.cu", - "flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm90.cu", - "flash-attn/instantiations/flash_bwd_hdim64_bf16_sm90.cu", - "flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm90.cu", - "flash-attn/instantiations/flash_bwd_hdim64_fp16_sm90.cu", - "flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm90.cu", - "flash-attn/instantiations/flash_bwd_hdim96_bf16_sm90.cu", - "flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm90.cu", - "flash-attn/instantiations/flash_bwd_hdim96_fp16_sm90.cu", - "flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_e4m3_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_fp16_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_bf16_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_e4m3_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_fp16_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_bf16_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_e4m3_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_fp16_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_bf16_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_e4m3_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_fp16_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_bf16_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_e4m3_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_fp16_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_bf16_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_e4m3_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_fp16_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu", - "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu", -] -include = ["flash-attn"] -depends = ["torch", "cutlass_3_9"] - -# [kernel.flash_attn_sm100] -# backend = "cuda" -# cuda-capabilities = ["8.0", "9.0a", "10.0"] -# cuda-flags = [ -# "-O3", -# "-std=c++17", -# "--ftemplate-backtrace-limit=0", # To debug template code -# "--use_fast_math", -# "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", -# "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1", -# "-DCUTLASS_ENABLE_GDC_FOR_SM90", -# "--expt-relaxed-constexpr", -# "--expt-extended-lambda", -# "--use_fast_math", -# "-DNDEBUG", -# ] -# src = [ -# "flash-attn/block.h", -# "flash-attn/copy_sm90_bulk_reduce.hpp", -# "flash-attn/epilogue_bwd.hpp", -# "flash-attn/epilogue_fwd.hpp", -# "flash-attn/flash.h", -# "flash-attn/flash_bwd_kernel_sm80.h", -# "flash-attn/flash_bwd_kernel_sm90.h", -# "flash-attn/flash_bwd_launch_template.h", -# "flash-attn/flash_bwd_postprocess_kernel.h", -# "flash-attn/flash_bwd_preprocess_kernel.h", -# "flash-attn/flash_fwd_launch_template.h", -# "flash-attn/flash_fwd_kernel_sm80.h", -# "flash-attn/flash_fwd_kernel_sm90.h", -# "flash-attn/heuristics.h", -# "flash-attn/mainloop_bwd_sm80.hpp", -# "flash-attn/mainloop_fwd_sm80.hpp", -# "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp", -# "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp", -# "flash-attn/mask.h", -# "flash-attn/named_barrier.hpp", -# "flash-attn/pack_gqa.h", -# "flash-attn/paged_kv.h", -# "flash-attn/rotary.h", -# "flash-attn/sm90_pipeline_no_cluster.hpp", -# "flash-attn/softmax.h", -# "flash-attn/tile_size.h", -# "flash-attn/tile_scheduler.hpp", -# -# "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm100.cu", -# ] -# include = ["flash-attn"] -# depends = ["torch", "cutlass_3_9"] diff --git a/flake.lock b/flake.lock deleted file mode 100644 index db19a7703d95f7c33290e9695ab1701a55611ce6..0000000000000000000000000000000000000000 --- a/flake.lock +++ /dev/null @@ -1,168 +0,0 @@ -{ - "nodes": { - "flake-compat": { - "locked": { - "lastModified": 1747046372, - "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", - "owner": "edolstra", - "repo": "flake-compat", - "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", - "type": "github" - }, - "original": { - "owner": "edolstra", - "repo": "flake-compat", - "type": "github" - } - }, - "flake-compat_2": { - "locked": { - "lastModified": 1733328505, - "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=", - "owner": "edolstra", - "repo": "flake-compat", - "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec", - "type": "github" - }, - "original": { - "owner": "edolstra", - "repo": "flake-compat", - "type": "github" - } - }, - "flake-utils": { - "inputs": { - "systems": "systems" - }, - "locked": { - "lastModified": 1731533236, - "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", - "owner": "numtide", - "repo": "flake-utils", - "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", - "type": "github" - }, - "original": { - "owner": "numtide", - "repo": "flake-utils", - "type": "github" - } - }, - "flake-utils_2": { - "inputs": { - "systems": "systems_2" - }, - "locked": { - "lastModified": 1731533236, - "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", - "owner": "numtide", - "repo": "flake-utils", - "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", - "type": "github" - }, - "original": { - "owner": "numtide", - "repo": "flake-utils", - "type": "github" - } - }, - "hf-nix": { - "inputs": { - "flake-compat": "flake-compat_2", - "flake-utils": "flake-utils_2", - "nixpkgs": "nixpkgs" - }, - "locked": { - "lastModified": 1750234878, - "narHash": "sha256-q9DRC9zdpzUf88qqg1qbhP1qgJbE2cMtn8oUmosuyT8=", - "owner": "huggingface", - "repo": "hf-nix", - "rev": "c7132f90763d756da3e77da62e01be0a4546dc57", - "type": "github" - }, - "original": { - "owner": "huggingface", - "repo": "hf-nix", - "type": "github" - } - }, - "kernel-builder": { - "inputs": { - "flake-compat": "flake-compat", - "flake-utils": "flake-utils", - "hf-nix": "hf-nix", - "nixpkgs": [ - "kernel-builder", - "hf-nix", - "nixpkgs" - ] - }, - "locked": { - "lastModified": 1751014803, - "narHash": "sha256-9Xfq2k3uPfB602NwQF+zAY2GQZiKUN1G7Q6XiDCUR8Y=", - "owner": "huggingface", - "repo": "kernel-builder", - "rev": "bbc4e712ff2046e217818e97de2201e2b996756e", - "type": "github" - }, - "original": { - "owner": "huggingface", - "repo": "kernel-builder", - "type": "github" - } - }, - "nixpkgs": { - "locked": { - "lastModified": 1747820358, - "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=", - "owner": "danieldk", - "repo": "nixpkgs", - "rev": "d3c1681180717528068082103bf323147de6ab0b", - "type": "github" - }, - "original": { - "owner": "danieldk", - "ref": "cudatoolkit-12.9-kernel-builder", - "repo": "nixpkgs", - "type": "github" - } - }, - "root": { - "inputs": { - "kernel-builder": "kernel-builder" - } - }, - "systems": { - "locked": { - "lastModified": 1681028828, - "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", - "owner": "nix-systems", - "repo": "default", - "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", - "type": "github" - }, - "original": { - "owner": "nix-systems", - "repo": "default", - "type": "github" - } - }, - "systems_2": { - "locked": { - "lastModified": 1681028828, - "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", - "owner": "nix-systems", - "repo": "default", - "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", - "type": "github" - }, - "original": { - "owner": "nix-systems", - "repo": "default", - "type": "github" - } - } - }, - "root": "root", - "version": 7 -} diff --git a/flake.nix b/flake.nix deleted file mode 100644 index e9486be32f210bc91d182010cc23a2dd817d1b03..0000000000000000000000000000000000000000 --- a/flake.nix +++ /dev/null @@ -1,51 +0,0 @@ -{ - description = "Flake for Hopper Flash Attention kernel"; - - inputs = { - kernel-builder.url = "github:huggingface/kernel-builder"; - }; - - outputs = - { - self, - kernel-builder, - }: - kernel-builder.lib.genFlakeOutputs { - path = ./.; - rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate; - # Building with CDUA later than 12.4 fails with: - # - # error: 'ptxas' died due to signal 11 (Invalid memory reference) - # - # So, build for 12.4 only and copy to all the other build variants - # by hand (which works fine thanks to backward compat). - # - # Still need to check if upstream FA3 has the same issue. - torchVersions = [ - { - torchVersion = "2.6"; - cudaVersion = "12.4"; - cxx11Abi = false; - systems = [ "x86_64-linux" ]; - upstreamVariant = true; - } - { - torchVersion = "2.6"; - cudaVersion = "12.4"; - cxx11Abi = true; - systems = [ "x86_64-linux" ]; - upstreamVariant = true; - } - { - torchVersion = "2.7"; - cudaVersion = "12.4"; - cxx11Abi = true; - systems = [ - "x86_64-linux" - "aarch64-linux" - ]; - upstreamVariant = true; - } - ]; - }; -} diff --git a/flash-attn/block.h b/flash-attn/block.h deleted file mode 100644 index eda7eaa1c403e09ded9266542aa7c58db4d601d5..0000000000000000000000000000000000000000 --- a/flash-attn/block.h +++ /dev/null @@ -1,94 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -namespace flash { - -template -struct BlockMN { - - static - CUTLASS_DEVICE - cute::tuple get_n_block_min_max( - SeqlenInfo_t const& seqlen_info, - int const m_block, int const bidb, int const split_idx, int const num_splits, - int const window_size_left, int const window_size_right, - cutlass::FastDivmod const& qhead_per_khead_divmod) { - - int const seqlen_k = seqlen_info.seqlen_k; - int const seqlen_q = seqlen_info.seqlen_q; - int n_block_max = cute::ceil_div(seqlen_k, kBlockN); - if constexpr (Is_causal || Is_local) { - int m_idx_max = (m_block + 1) * kBlockM; - // TODO: check off-by-1 error - if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; } - n_block_max = std::min(n_block_max, - cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + window_size_right, kBlockN)); - } - int n_block_min = 0; - if constexpr (Is_local) { - int m_idx_min = m_block * kBlockM; - if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); } - n_block_min = std::max(int(0), (m_idx_min + seqlen_k - seqlen_q - window_size_left) / kBlockN); - } - // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } - if constexpr (Split) { - uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits - int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u); - int split_idx_actual = split_idx & 0x0000FFFF; - int num_splits_actual = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; - int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits_actual); - n_block_min = n_block_min + split_idx_actual * num_n_blocks_per_split; - n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max); - // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, num_splits_dynamic = %d, num_splits_actual = %d, num_n_blocks_per_split = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, num_splits_dynamic, num_splits_actual, num_n_blocks_per_split, n_block_min, n_block_max); } - } - // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } - return {n_block_min, n_block_max}; - } - - static - CUTLASS_DEVICE - cute::tuple get_n_block_k_new_min_max( - SeqlenInfo_t const& seqlen_info, - int const m_block, int const bidb, int const split_idx, int const num_splits, - int const window_size_left, int const window_size_right, - cutlass::FastDivmod const& qhead_per_khead_divmod) { - - auto [n_block_min, n_block_max] = get_n_block_min_max( - seqlen_info, m_block, bidb, split_idx, num_splits, - window_size_left, window_size_right, qhead_per_khead_divmod); - int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0); - int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new); - int const n_block_new_min = idx_k_new_min / kBlockN; - int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min; - // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);} - return {n_block_new_min, n_block_new_max}; - } - - static - CUTLASS_DEVICE - cute::tuple get_m_block_min_max( - SeqlenInfo_t const& seqlen_info, - int const n_block, int const bidb, - int const window_size_left, int const window_size_right, int const sink_token_length) { - - int const seqlen_q = seqlen_info.seqlen_q; - int const seqlen_k = seqlen_info.seqlen_k; - int m_block_max = cute::ceil_div(seqlen_q, kBlockM); - if constexpr (Is_local) { - if (n_block >= cute::ceil_div(sink_token_length, kBlockN)) { - m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_size_left, kBlockM)); - } - } - int m_block_min = 0; - if constexpr (Is_causal || Is_local) { - m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - window_size_right) / kBlockM); - } - return {m_block_min, m_block_max}; - } - -}; - -} // namespace flash diff --git a/flash-attn/copy_sm90_bulk_reduce.hpp b/flash-attn/copy_sm90_bulk_reduce.hpp deleted file mode 100644 index 8556fae66d6d05af6c82bdc49735fdfece2a978b..0000000000000000000000000000000000000000 --- a/flash-attn/copy_sm90_bulk_reduce.hpp +++ /dev/null @@ -1,49 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -namespace cute -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct SM90_BULK_REDUCE_ADD -{ - CUTE_HOST_DEVICE static void - copy(float const* smem_ptr, - float * gmem_ptr, int32_t store_bytes) - { -#if defined(CUTE_ARCH_TMA_SM90_ENABLED) - uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); - asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n" - : - : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes) - : "memory"); -#else - CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED."); -#endif - } - - CUTE_HOST_DEVICE static void - copy(float const* smem_ptr, - float * gmem_ptr, int32_t store_bytes, uint64_t cache_hint) - { -#if defined(CUTE_ARCH_TMA_SM90_ENABLED) - uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); - asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [%0], [%1], %2, %3;\n" - : - : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes), "l"(cache_hint) - : "memory"); -#else - CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED."); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // end namespace cute diff --git a/flash-attn/cuda_check.h b/flash-attn/cuda_check.h deleted file mode 100644 index b5e63aef79d22f9afdf83da05dbad0f2b8397ac9..0000000000000000000000000000000000000000 --- a/flash-attn/cuda_check.h +++ /dev/null @@ -1,19 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include -#include - -#define CHECK_CUDA(call) \ - do { \ - cudaError_t status_ = call; \ - if (status_ != cudaSuccess) { \ - fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ - exit(1); \ - } \ - } while(0) - -#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) diff --git a/flash-attn/epilogue_bwd.hpp b/flash-attn/epilogue_bwd.hpp deleted file mode 100644 index 9362b0404531ee7219eeeb0533d2843bdffb5723..0000000000000000000000000000000000000000 --- a/flash-attn/epilogue_bwd.hpp +++ /dev/null @@ -1,523 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/barrier.h" -#include "cute/tensor.hpp" - -#include "cutlass/gemm/collective/builders/sm90_common.inl" - -#include "seqlen.h" -#include "named_barrier.hpp" -#include "utils.h" - -namespace flash { - -using namespace cute; - -template -struct CollectiveEpilogueBwd { - - using TileShape_MNK = TileShape_MNK_; - using Element = Element_; - using ArchTag = ArchTag_; - static constexpr int NumEpilogueThreads = NumEpilogueThreads_; - static constexpr bool Varlen = Varlen_; - static constexpr bool dKV_swapAB = dKV_swapAB_; - static constexpr bool Use_TMA = !Varlen && ArchTag::kMinComputeCapability >= 90; - - static_assert(ArchTag::kMinComputeCapability >= 80); - - using GmemTiledCopydKVTMA = cute::SM90_TMA_STORE; - - // These are for storing the output tensor without TMA (e.g., for setting output to zero) - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(get<2>(TileShape_MNK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); - static constexpr int kHeadDim = get<2>(TileShape_MNK{}); - static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, NumEpilogueThreads); - static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow"); - using GmemLayoutAtom = Layout, Int>, - Stride, _1>>; - using GmemTiledCopydKV = decltype( - make_tiled_copy(Copy_Atom, Element>{}, - GmemLayoutAtom{}, - Layout>>{})); // Val layout, 8 or 16 vals per store - - using SmemLayoutAtomdKVTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), Int(TileShape_MNK{})) / AtomLayoutKdKV>>()); - using SmemLayoutdKVTMA = decltype(tile_to_shape(SmemLayoutAtomdKVTMA{}, select<1, 2>(TileShape_MNK{}))); - using SmemLayoutdKVtTMA = - decltype(cute::composition(SmemLayoutdKVTMA{}, - make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), - make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{})))); - - // If we don't use TMA - static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : (kHeadDim % 32 == 0 ? 32 : 16); - static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1); - using SmemLayoutAtomdKVSTG = - decltype(composition(Swizzle{}, - Layout, Int>, - Stride, _1>>{})); - - using SmemLayoutAtomdKV = std::conditional_t; - using SmemLayoutdKV = decltype(tile_to_shape(SmemLayoutAtomdKV{}, select<1, 2>(TileShape_MNK{}))); - using SmemLayoutdKVt = - decltype(cute::composition(SmemLayoutdKV{}, - make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), - make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{})))); - - using SmemCopyAtomdKV = Copy_Atom< - std::conditional_t< - ArchTag::kMinComputeCapability >= 90, - std::conditional_t, - AutoVectorizingCopyWithAssumedAlignment<128> - >, - Element>; - - static constexpr size_t SmemAlignmentdKV = ArchTag::kMinComputeCapability >= 90 ? cutlass::detail::alignment_for_swizzle(SmemLayoutdKV{}) : 128; - static_assert(SmemAlignmentdKV >= 128, "Require at least 128B alignment"); - - struct TensorStorage : cute::aligned_struct { - cute::array_aligned, SmemAlignmentdKV> smem_dk; - cute::array_aligned, SmemAlignmentdKV> smem_dv; - }; - - using ShapedKV = cute::Shape; // (seqlen_k, d, head, batch) - using StridedKV = cute::Stride; - - using TMA_dKV = std::conditional_t< - Use_TMA, - decltype(make_tma_copy( - GmemTiledCopydKVTMA{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapedKV{}, StridedKV{}), - SmemLayoutdKVTMA{}, - select<1, 2>(TileShape_MNK{}), - _1{})), // no mcast for dKV - std::nullptr_t - >; - - // Host side kernel arguments - struct Arguments { - Element* ptr_dK; - ShapedKV const shape_dK; - StridedKV const stride_dK; - Element* ptr_dV; - StridedKV const stride_dV; - int const num_heads_q; - int* dk_semaphore; - int* dv_semaphore; - int const* cu_seqlens; - int const* seqused; - }; - - // Device side kernel params - struct Params { - Element* ptr_dK; - ShapedKV const shape_dK; - StridedKV const stride_dK; - Element* ptr_dV; - StridedKV const stride_dV; - TMA_dKV tma_store_dK, tma_store_dV; - int const* cu_seqlens = nullptr; - int const* seqused = nullptr; - }; - - static Params - to_underlying_arguments(Arguments const& args) { - Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK); - Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dK, args.stride_dV); - TMA_dKV tma_store_dK = [&] { - if constexpr (Use_TMA) { - return make_tma_copy(GmemTiledCopydKVTMA{}, mdK, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV - } else { - return nullptr; - } - }(); - TMA_dKV tma_store_dV = [&] { - if constexpr (Use_TMA) { - return make_tma_copy(GmemTiledCopydKVTMA{}, mdV, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV - } else { - return nullptr; - } - }(); - return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.stride_dV, - tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused}; - } - - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& params) { - if constexpr (Use_TMA) { - cute::prefetch_tma_descriptor(params.tma_store_dK.get_tma_descriptor()); - cute::prefetch_tma_descriptor(params.tma_store_dV.get_tma_descriptor()); - } - } - - template - CUTLASS_DEVICE void - store(Params const& params, - FrgTensorO const& tdKrdK, - FrgTensorO const& tdVrdV, - SharedStorage& shared_storage, - TiledMma tiled_mma, - int thread_idx, - cute::tuple const& block_coord - ) { - - auto [n_block, bidh, bidb] = block_coord; - Tensor sdK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKV{})); - Tensor sdV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKV{})); - Tensor sdKt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKVt{})); - Tensor sdVt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKVt{})); - auto smem_tiled_copy_dKV = make_tiled_copy_C(SmemCopyAtomdKV{}, tiled_mma); - auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(thread_idx); - - Tensor tdVrdV_out = make_tensor_like(tdVrdV); - flash::convert_type_out(tdVrdV, tdVrdV_out); - Tensor tdKrdK_out = make_tensor_like(tdKrdK); - flash::convert_type_out(tdKrdK, tdKrdK_out); - Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N) - Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N) - // if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_dKV); print(sdK); printf("\n"); print(sdKt); printf("\n"); } - Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(cute::conditional_return(sdK, sdKt)); // ((Atom,AtomNum),PIPE_M,PIPE_N) - Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(cute::conditional_return(sdV, sdVt)); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - // Make sure all WGs have finished reading K and V - flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV); - cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK); - if constexpr (Use_TMA) { - cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA - cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, - cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - - Tensor mdK = params.tma_store_dK.get_tma_tensor(params.shape_dK); - Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dK); - Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) - Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) - auto block_tma_dK = params.tma_store_dK.get_slice(_0{}); - auto block_tma_dV = params.tma_store_dV.get_slice(_0{}); - Tensor tdKgdK = block_tma_dK.partition_D(gdK); // (TMA, TMA_M, TMA_K) - Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_M, TMA_K) - Tensor tdVgdV = block_tma_dV.partition_D(gdV); // (TMA, TMA_M, TMA_K) - Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_M, TMA_K) - int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0); - if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) { - cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp, - cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - if (cute::elect_one_sync()) { - cute::copy(params.tma_store_dV, tdVsdV, tdVgdV); - cute::copy(params.tma_store_dK, tdKsdK, tdKgdK); - tma_store_arrive(); - } - } - tma_store_wait<0>(); - // // Tell warp 0 that smem_k and smem_v are ready - // cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); - - } else { - flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused}; - bool const is_varlen = Varlen && params.cu_seqlens; - Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0); - Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) - Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0); - Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) - - GmemTiledCopydKV gmem_tiled_copy_dKV; - auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx); - Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV); - Tensor tdKVsdV = gmem_thr_copy_dKV.partition_S(sdV); // (TMA, TMA_M, TMA_K) - Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK); - Tensor tdKVsdK = gmem_thr_copy_dKV.partition_S(sdK); // (TMA, TMA_M, TMA_K) - Tensor tdKVrdV = make_fragment_like(tdKVgdV); - Tensor tdKVrdK = make_fragment_like(tdKVgdK); - Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k) - // Repeat the partitioning with identity layouts - Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); - Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKVgdV))); - #pragma unroll - for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); } - // Need to check OOB when reading from smem if kBlockN isn't evenly tiled - static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0; - flash::copy( - gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV, tdKVcdKV, tdKVpdKV, kBlockN); - flash::copy( - gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK, tdKVcdKV, tdKVpdKV, kBlockN); - // // Tell warp 0 that smem_k and smem_v are ready - // cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_k/v - // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); - // Construct identity layout for gdKV - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdKV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) - ); - flash::copy( - gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdKV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) - ); - } - } - - CUTLASS_DEVICE void - store_tail() { - // if constexpr (Use_TMA) { tma_store_wait<0>(); } - } - - // Write 0 to dK and dV - CUTLASS_DEVICE void - store_zero( - Params const& params, - int thread_idx, - cute::tuple const& block_coord - ) { - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - auto [n_block, bidh, bidb] = block_coord; - flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused}; - bool const is_varlen = Varlen && params.cu_seqlens; - Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0); - Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) - Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0); - Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) - - GmemTiledCopydKV gmem_tiled_copy_dKV; - auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx); - Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK); - Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV); - Tensor tdKVrdKV = make_fragment_like(tdKVgdK); - clear(tdKVrdKV); - // Construct identity layout for gdKV - Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); - Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKVgdK))); - #pragma unroll - for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN - ); - flash::copy( - gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN - ); - } - -}; - -template -struct CollectiveEpilogueBwdGQA { - - using TileShape_MNK = TileShape_MNK_; - using Element = ElementAccum; - using ArchTag = ArchTag_; - static constexpr int NumEpilogueThreads = NumEpilogueThreads_; - static constexpr bool Varlen = Varlen_; - static constexpr bool Use_TMA = ArchTag::kMinComputeCapability >= 90; - - static_assert(ArchTag::kMinComputeCapability >= 80); - - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - static constexpr int kHeadDim = get<2>(TileShape_MNK{}); - static_assert(NumEpilogueThreads % cutlass::NumThreadsPerWarp == 0, "NumEpilogueThreads must be a multiple of NumThreadsPerWarp"); - static constexpr int NumWarpGroups = NumEpilogueThreads / cutlass::NumThreadsPerWarpGroup; - // Thread layout, 256 or 384 threads per row - // We split into NumWarpGroups so that we can use the same postprocessing kernel as dQ - using R2SLayoutAtomdKVaccum = Layout, Int>>; - using R2STiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2SLayoutAtomdKVaccum{}, - Layout>{})); // Val layout, 4 vals per store - // For Sm80 - using R2GLayoutAtomdKVaccum = Layout>>; - using R2GTiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2GLayoutAtomdKVaccum{}, - Layout>{})); // Val layout, 1 vals per store - - using SmemLayoutdKVaccum = Layout, Int>>; - using SmemLayoutdKVaccumFlat = Layout>>; - - // Strangely without this SmemAlignment, the total smem for hdim 128 (80 x 128) is 228KB even though we - // only need 227KB. We use the same alignment as the non-GQA epilogue to avoid this issue. - static constexpr int SmemAlignment = kHeadDim % 64 == 0 ? 1024 : (kHeadDim % 32 == 0 ? 512 : 256); - struct TensorStorageTMA : cute::aligned_struct { - cute::array_aligned, SmemAlignment> smem_dkv; - }; - struct TensorStorageSTG { - cute::array smem_dkv; - }; - using TensorStorage = std::conditional_t; - - using ShapedKV = cute::Shape; // (seqlen_k_rounded * d, head, batch) - using StridedKV = cute::Stride<_1, int64_t, int64_t>; - - // Host side kernel arguments - struct Arguments { - ElementAccum* ptr_dKaccum; - ShapedKV const shape_dKaccum; - StridedKV const stride_dKaccum; - ElementAccum* ptr_dVaccum; - StridedKV const stride_dVaccum; - int num_heads_q; - int* dk_semaphore; - int* dv_semaphore; - int const* cu_seqlens; - int const* seqused; - }; - - // Device side kernel params - struct Params { - ElementAccum* ptr_dKaccum; - ShapedKV const shape_dKaccum; - StridedKV const stride_dKaccum; - ElementAccum* ptr_dVaccum; - StridedKV const stride_dVaccum; - cutlass::FastDivmod qhead_per_khead_divmod; - int* dk_semaphore; - int* dv_semaphore; - int const* cu_seqlens = nullptr; - int const* seqused = nullptr; - }; - - static Params - to_underlying_arguments(Arguments const& args) { - if constexpr (Deterministic) { - assert(args.dk_semaphore != nullptr); - assert(args.dv_semaphore != nullptr); - } - return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.stride_dVaccum, - cutlass::FastDivmod(cute::ceil_div(args.num_heads_q, get<1>(args.shape_dKaccum))), - args.dk_semaphore, args.dv_semaphore, - args.cu_seqlens, args.seqused}; - } - - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& params) { - } - - template - CUTLASS_DEVICE void - store(Params const& params, - FrgTensorO const& tdKrdK, - FrgTensorO const& tdVrdV, - SharedStorage& shared_storage, - TiledMma tiled_mma, - int thread_idx, - cute::tuple const& block_coord - ) { - - auto [n_block, bidh, bidb] = block_coord; - int bidh_idx_in_group; - int bidh_kv = params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh); - Tensor sdKV = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccum{}); - Tensor sdKV_flat = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccumFlat{}); - static constexpr int dKV_TMA_num_bytes = CUTE_STATIC_V(size(sdKV_flat)) * sizeof(ElementAccum); - - flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_dKaccum), params.cu_seqlens, params.seqused}; - bool const is_varlen = Varlen && params.cu_seqlens; - Tensor mdKaccum = make_tensor(make_gmem_ptr(params.ptr_dKaccum), params.shape_dKaccum, params.stride_dKaccum)(_, bidh_kv, !is_varlen ? bidb : 0); - Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dKaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? bidb : 0); - Tensor gdKaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdKaccum), Shape>{}, make_coord(n_block)); // (M * K) - Tensor gdVaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdVaccum), Shape>{}, make_coord(n_block)); // (M * K) - - R2STiledCopydKVaccum r2s_tiled_copy_dKVaccum; - auto r2s_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_thread_slice(thread_idx); - Tensor tdKVsdKVaccum = r2s_thr_copy_dKVaccum.partition_D(sdKV); - - // Only used if !Use_TMA - R2GTiledCopydKVaccum r2g_tiled_copy_dKVaccum; - auto r2g_thr_copy_dKVaccum = r2g_tiled_copy_dKVaccum.get_thread_slice(thread_idx); - - // Make sure all WGs have finished reading K and V, otherwise we get racy dQ - // because smem_q could be changed. - flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - if constexpr (Use_TMA) { - Tensor taccdKVrdV = r2s_thr_copy_dKVaccum.retile_S(tdVrdV); // ((Atom,AtomNum), MMA_M, MMA_N) - cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdV, tdKVsdKVaccum); - } - - // int const num_batch = params.num_batch; - int const num_batch = get<2>(params.shape_dKaccum); - int const num_head_kv = get<1>(params.shape_dKaccum); - int *lock_ptr = !Deterministic ? nullptr : params.dv_semaphore + bidb * num_head_kv + bidh_kv; - using Barrier = cutlass::GenericBarrier; - - // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);} - - if constexpr (Deterministic) { - Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group); - } - // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore);} - if constexpr (Use_TMA) { - cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - if (thread_idx == 0) { - SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdVaccum.data()), dKV_TMA_num_bytes, static_cast(TMA::CacheHintSm90::EVICT_LAST)); - tma_store_arrive(); - tma_store_wait<0>(); - } - } else { - Tensor tdVrdV_atomic = r2g_thr_copy_dKVaccum.retile_S(tdVrdV); - Tensor tdVgdV_atomic = r2g_thr_copy_dKVaccum.partition_D(gdVaccum); - static_assert(CUTE_STATIC_V(size(tdVrdV_atomic)) == CUTE_STATIC_V(size(tdVgdV_atomic))); - #pragma unroll - for (int i = 0; i < size(tdVrdV_atomic); ++i) { atomicAdd(&tdVgdV_atomic(i), tdVrdV_atomic(i)); } - } - if constexpr (Deterministic) { - Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv); - } - - if constexpr (Use_TMA) { - cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - Tensor taccdKVrdK = r2s_thr_copy_dKVaccum.retile_S(tdKrdK); // ((Atom,AtomNum), MMA_M, MMA_N) - cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdK, tdKVsdKVaccum); - } - lock_ptr = !Deterministic ? nullptr : params.dk_semaphore + bidb * num_head_kv + bidh_kv; - // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);} - - if constexpr (Deterministic) { - Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group); - } - // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore);} - if constexpr (Use_TMA) { - cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - if (thread_idx == 0) { - SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdKaccum.data()), dKV_TMA_num_bytes, static_cast(TMA::CacheHintSm90::EVICT_LAST)); - tma_store_arrive(); - tma_store_wait<0>(); - } - } else { - Tensor tdKrdK_atomic = r2g_thr_copy_dKVaccum.retile_S(tdKrdK); - Tensor tdKgdK_atomic = r2g_thr_copy_dKVaccum.partition_D(gdKaccum); - static_assert(CUTE_STATIC_V(size(tdKrdK_atomic)) == CUTE_STATIC_V(size(tdKgdK_atomic))); - #pragma unroll - for (int i = 0; i < size(tdKrdK_atomic); ++i) { atomicAdd(&tdKgdK_atomic(i), tdKrdK_atomic(i)); } - } - if constexpr (Deterministic) { - Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv); - } - // // Tell warp 0 that smem_k and smem_v are ready - // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); - } - - CUTLASS_DEVICE void - store_tail() { - } - - // Write 0 to dK and dV - CUTLASS_DEVICE void - store_zero( - Params const& params, - int thread_idx, - cute::tuple const& block_coord - ) { - // Don't need to do anything since dKaccum and dVaccum are already zero-initialized - } - -}; - -} // namespace flash diff --git a/flash-attn/epilogue_fwd.hpp b/flash-attn/epilogue_fwd.hpp deleted file mode 100644 index 69102e8c4e6e9fc144fd0fcf3d624890781afb84..0000000000000000000000000000000000000000 --- a/flash-attn/epilogue_fwd.hpp +++ /dev/null @@ -1,484 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include -#include // For FastDivMod -#include "cute/tensor.hpp" - -#include "cutlass/gemm/collective/builders/sm90_common.inl" -#include "cutlass/epilogue/collective/builders/sm90_common.inl" - -#include "seqlen.h" -#include "named_barrier.hpp" -#include "pack_gqa.h" -#include "utils.h" - -namespace flash { - -using namespace cute; - -template -struct CollectiveEpilogueFwd { - - using TileShape_MNK_PV = TileShape_MNK_PV_; - using ClusterShape = ClusterShape_; - using Element = Element_; - using ElementPartial = float; - using ArchTag = ArchTag_; - static constexpr int NumEpilogueThreads = NumEpilogueThreads_; - static constexpr bool Varlen = Varlen_; - static constexpr bool PackGQA = PackGQA_; - static constexpr bool Split = Split_; - static constexpr bool Use_smem = !(Split && !Varlen); - static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA; - - static_assert(ArchTag::kMinComputeCapability >= 80); - static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1); - static_assert(sizeof(Element) <= 2); - - static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); - static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{}); - - static constexpr bool LargeHeadDimV = kHeadDimV > 256; - - using GmemTiledCopyOTMA = cute::SM90_TMA_STORE; - - // These are for storing the output tensor without TMA (e.g., for setting output to zero) - static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDimV % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore"); - // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements - // in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times - // we need to call divmod. - static constexpr int kBytePerRow = kHeadDimV * sizeof(Element); - static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); - static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore; - // If PackGQA, we split the work of compute O_ptr among threads in the same row, so we need this to within a warp - static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0); - static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow"); - using GmemLayoutAtom = Layout, Int>, - Stride, _1>>; - static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumEpilogueThreads / kGmemThreadsPerRow"); - using GmemTiledCopyO = decltype( - make_tiled_copy(Copy_Atom, Element>{}, - GmemLayoutAtom{}, - Layout>>{})); // Val layout, 8 or 16 vals per store - - using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_PV{})), decltype(cute::get<1>(TileShape_MNK_PV{}))>()); - using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 1>(TileShape_MNK_PV{}))); - static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1)); - static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4); - using SmemLayoutAtomO = decltype( - composition(Swizzle{}, - Layout>, - Stride, _1>>{})); - using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_MNK_PV{}))); - using SmemLayoutO = std::conditional_t= 90, SmemLayoutOTMA, SmemLayoutOSTS>; - - using ShapeO = cute::Shape; // (seqlen_q, d, head, batch, num_splits) - using StrideO = cute::Stride; - using StrideLSE = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen_q, head, batch, num_splits) - // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits) - using ShapeOPacked = std::conditional_t, int32_t, int32_t, int32_t, int32_t>>; - using StrideOPacked = std::conditional_t, _1, int64_t, int64_t, int64_t>>; - // ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits) - using ShapeLSEPacked = std::conditional_t, cute::Shape, int32_t, int32_t, int32_t>>; - using StrideLSEPacked = std::conditional_t, int64_t, int64_t, int64_t>>; - - using CopyOpR2S = std::conditional_t< - ArchTag::kMinComputeCapability >= 90, - // cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16) - decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator()), - AutoVectorizingCopyWithAssumedAlignment<128> - >; - using SmemCopyAtomO = Copy_Atom; - - // static constexpr size_t SmemAlignmentO = cutlass::detail::alignment_for_swizzle(SmemLayoutO{}); - // static_assert(SmemAlignmentO >= 128, "Require at least 128B alignment"); - // struct TensorStorage : cute::aligned_struct { - // cute::array_aligned : 0, SmemAlignmentO> smem_o; - // }; - struct TensorStorage : cute::aligned_struct<128> { - cute::array_aligned : 0> smem_o; - }; - - using TMA_O = std::conditional_t< - Use_TMA_O, - decltype(make_tma_copy( - GmemTiledCopyOTMA{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeO{}, StrideO{}), - SmemLayoutOTMA{}, - select<0, 1>(TileShape_MNK_PV{}), - _1{})), // no mcast for O - std::nullptr_t - >; - - // Host side kernel arguments - struct Arguments { - Element* ptr_O; - ShapeO const shape_O; - StrideO const stride_O; - ElementPartial* ptr_O_partial; - StrideO const stride_O_partial; - float* ptr_LSE; - StrideLSE const stride_LSE; - float* ptr_LSE_partial; - StrideLSE const stride_LSE_partial; - int32_t const nheads_kv; - int const* cu_seqlens = nullptr; - int const* seqused = nullptr; - }; - - // Device side kernel params - struct Params { - Element* ptr_O; - ShapeO const shape_O; - StrideO const stride_O; - ShapeOPacked const shape_O_packed; - StrideOPacked const stride_O_packed; - ElementPartial* ptr_O_partial; - StrideO const stride_O_partial; - StrideOPacked const stride_O_partial_packed; - float* ptr_LSE; - StrideLSE const stride_LSE; - ShapeLSEPacked const shape_LSE_packed; - StrideLSEPacked const stride_LSE_packed; - float* ptr_LSE_partial; - StrideLSE const stride_LSE_partial; - StrideLSEPacked const stride_LSE_partial_packed; - cutlass::FastDivmod qhead_per_khead_divmod; - TMA_O tma_store_O; - int const* cu_seqlens = nullptr; - int const* seqused = nullptr; - }; - - static Params - to_underlying_arguments(Arguments const& args) { - Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O); - TMA_O tma_store_O = [&]{ - if constexpr (Use_TMA_O) { - return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast - } else { - return nullptr; - } - }(); - // If PackGQA, reshape O to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size, num_splits) - int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_O), args.nheads_kv); - auto const shape_O_packed = cute::conditional_return( - args.shape_O, - make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O)) - ); - auto const stride_O_packed = cute::conditional_return( - args.stride_O, - make_stride(make_stride(get<2>(args.stride_O), get<0>(args.stride_O)), get<1>(args.stride_O), get<2>(args.stride_O) * qhead_per_khead, get<3>(args.stride_O), get<4>(args.stride_O)) - ); - auto const stride_O_partial_packed = cute::conditional_return( - args.stride_O_partial, - make_stride(make_stride(get<2>(args.stride_O_partial), get<0>(args.stride_O_partial)), get<1>(args.stride_O_partial), get<2>(args.stride_O_partial) * qhead_per_khead, get<3>(args.stride_O_partial), get<4>(args.stride_O_partial)) - ); - // If PackGQA, Reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits) - auto const shape_LSE_packed = cute::conditional_return( - select<0, 2, 3, 4>(args.shape_O), - make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O)) - ); - auto const stride_LSE_packed = cute::conditional_return( - args.stride_LSE, - make_stride(make_stride(get<1>(args.stride_LSE), get<0>(args.stride_LSE)), get<1>(args.stride_LSE) * qhead_per_khead, get<2>(args.stride_LSE), get<3>(args.stride_LSE)) - ); - auto const stride_LSE_partial_packed = cute::conditional_return( - args.stride_LSE_partial, - make_stride(make_stride(get<1>(args.stride_LSE_partial), get<0>(args.stride_LSE_partial)), get<1>(args.stride_LSE_partial) * qhead_per_khead, get<2>(args.stride_LSE_partial), get<3>(args.stride_LSE_partial)) - ); - return {args.ptr_O, args.shape_O, args.stride_O, shape_O_packed, stride_O_packed, - args.ptr_O_partial, args.stride_O_partial, stride_O_partial_packed, - args.ptr_LSE, args.stride_LSE, shape_LSE_packed, stride_LSE_packed, - args.ptr_LSE_partial, args.stride_LSE_partial, stride_LSE_partial_packed, - cutlass::FastDivmod(qhead_per_khead), - tma_store_O, args.cu_seqlens, args.seqused}; - } - - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& params) { - if constexpr (Use_TMA_O) { - cute::prefetch_tma_descriptor(params.tma_store_O.get_tma_descriptor()); - } - } - - template - CUTLASS_DEVICE void - store(Params const& params, - FrgTensorO& tOrO, - FrgTensorLSE const& lse, - SharedStorage& shared_storage, - TiledMma tiled_mma, - int thread_idx, - cute::tuple const& block_coord - ) { - - auto [m_block, bidh, bidb, split_idx] = block_coord; - int num_splits = get<4>(params.shape_O_packed); - if constexpr (Split && Varlen) { - uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits - int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u); - num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; - split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx - } - bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1); - - Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{}); - // Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO); - - static constexpr bool NeedFP8Permute = FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4); - // If we will possibly need tOrO in FP32, we'd want to permute tOrO before type conversion. - // Otherwise we can permute after conversion. - if constexpr (NeedFP8Permute && Split) { flash::permute_output_fp8_Vcolmajor(tOrO); } - Tensor tOrO_out = make_tensor_like(tOrO); - flash::convert_type_out(tOrO, tOrO_out); - if constexpr (NeedFP8Permute && !Split) { flash::permute_output_fp8_Vcolmajor(tOrO_out); } - - // Make sure all WGs have finished reading V - // Technically we don't need this if we're not using smem, but the mainloop makes the assumption that - // all epilogue threads sync at least once during the epilogue (so that we can start loading Q with - // cp.async if we need). - flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - - // Step 1: Write O from rmem -> smem - if constexpr (Use_smem) { - auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma); - auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); - Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N) - Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) - // Tensor taccOsO = smem_thr_copy_O.partition_D(sO_pi); // ((Atom,AtomNum),PIPE_M,PIPE_N) - cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); - if constexpr (Use_TMA_O) { - cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA - cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, - cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - } else { - flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - } - } else { - if constexpr (ArchTag::kMinComputeCapability >= 90) { - #pragma unroll - for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { - shared_storage.pipelines.barrier_O.arrive(cta_id); - } - } - } - - flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused}; - bool is_varlen = Varlen && params.cu_seqlens; - int offset_o = seqlen_info.offset; - int seqlen_o = seqlen_info.seqlen; - int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); - - // Step 2: Write LSE from rmem -> gmem - auto thread_mma = tiled_mma.get_thread_slice(thread_idx); - // (MMA,MMA_M,MMA_K) - Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); - static_assert(decltype(size<0, 0>(taccOcO))::value == 2); - static_assert(decltype(size<0, 1>(taccOcO))::value == 2); - Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); - Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); - CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M - - using PackGQA_t = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; - using PackGQApartial_t = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, ElementPartial>; - - Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)), - params.shape_LSE_packed, - !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx); - // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); } - if (!LargeHeadDimV || warp_group_idx == 0) { - if constexpr (!PackGQA) { - #pragma unroll - for (int mi = 0; mi < size(lse); ++mi) { - int const row = m_block * kBlockM + get<0>(taccOcO_row(mi)); - if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); } - } - } else { - PackGQA_t::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); - } - } - - // Step 3: Write O from smem -> gmem - if constexpr (Use_TMA_O) { - Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx); - Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) - auto block_tma_O = params.tma_store_O.get_slice(_0{}); - Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K) - Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K) - int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0); - if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) { - cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp, - cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - if (cute::elect_one_sync()) { - cute::copy(params.tma_store_O, tOsO, tOgO); - tma_store_arrive(); - tma_store_wait<0>(); - #pragma unroll - for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { - shared_storage.pipelines.barrier_O.arrive(cta_id); - } - } - } - } else { // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence - if (!is_split) { - Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{}); - Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) - // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast(&mO(0)) - reinterpret_cast(params.ptr_O)); } - GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); - Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) - // Tensor tOsO = gmem_thr_copy_O.partition_S(sO_pi); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tOrO = make_fragment_like(tOsO); - cute::copy(gmem_tiled_copy_O, tOsO, tOrO); - if constexpr (ArchTag::kMinComputeCapability >= 90) { - cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_v - #pragma unroll - for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { - shared_storage.pipelines.barrier_O.arrive(cta_id); - } - } - if constexpr (!PackGQA) { - // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); - Tensor tOpO = make_tensor(make_shape(size<2>(tOsO))); - #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } - Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM - ); - } else { - // If PackGQA, we split the work of compute O_ptr among threads in the same row - PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); - } - } else { - Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset_o * get<0>(params.stride_O_partial)), params.shape_O_packed, params.stride_O_partial_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx); - Tensor gOpartial = local_tile(mOpartial, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) - // We already arrived on barrier_O earlier if !Use_smem - if constexpr (Use_smem) { - if constexpr (ArchTag::kMinComputeCapability >= 90) { - #pragma unroll - for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { - shared_storage.pipelines.barrier_O.arrive(cta_id); - } - } - } - if constexpr (!PackGQA) { - static constexpr int kGmemElemsPerStoreDirect = 2; - cute::Copy_Atom, ElementPartial> gmem_copy_direct; - // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) - Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout())); - Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int>{}); - Tensor tOgO = thread_mma.partition_C(gOpartial); - Tensor tOgO_rowcol = make_tensor(tOgO.data(), flash::convert_layout_acc_rowcol(tOgO.layout())); - Tensor tOgO_copy = cute::tiled_divide(tOgO_rowcol, Shape<_1, Int>{}); - Tensor taccOcO_col = taccOcO_rowcol(_0{}, _); - #pragma unroll - for (int m = 0; m < size(taccOcO_row); ++m) { - if (get<0>(taccOcO_row(m)) < seqlen_o - m_block * kBlockM) { - #pragma unroll - for (int k = 0; k < size(taccOcO_col) / kGmemElemsPerStoreDirect; ++k) { - if (get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)) < get<1>(params.shape_O)) { - cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), tOgO_copy(_, m, k)); - } - } - } - } - } else { - PackGQApartial_t::store_O_direct(mOpartial, tOrO, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); - } - } - } - } - - CUTLASS_DEVICE void - store_tail() { - // Don't need to do tma_store_wait<0>() here since we already did in @store - } - - // Write 0 to output and -inf to LSE - CUTLASS_DEVICE void - store_zero( - Params const& params, - int thread_idx, - cute::tuple const& block_coord - ) { - static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); - auto [m_block, bidh, bidb, split_idx] = block_coord; - int num_splits = get<4>(params.shape_O_packed); - if constexpr (Split && Varlen) { - uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits - int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u); - num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; - split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx - } - bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1); - - flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused}; - bool const is_varlen = Varlen && params.cu_seqlens; - int offset_o = seqlen_info.offset; - int seqlen_o = seqlen_info.seqlen; - int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor; - Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)), - params.shape_LSE_packed, - !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx); - Tensor gLSE = local_tile(mLSE, Shape>{}, make_coord(m_block)); - - static_assert(kBlockM <= NumEpilogueThreads); - if (thread_idx < kBlockM) { - const int row = m_block * kBlockM + thread_idx; - if constexpr (!PackGQA) { - if (row < seqlen_o) { mLSE(row) = -INFINITY; } - } else { - if (row < seqlen_o * qhead_per_khead) { - int m_idx, h_idx; - m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row); - // mLSE has shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord" - mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY; - } - } - } - - // If split, we don't have to write 0 to mOpartial if the mha_combine kernel is used, - // since it will not use the value of O if LSE is -inf. - if (!is_split) { - Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{}); - - GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); - Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); - if constexpr (!PackGQA) { - Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); - #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } - Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) - Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - Tensor tOrO = make_fragment_like(tOgO); - cute::clear(tOrO); - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM - ); - } else { - // If PackGQA, we split the work of compute O_ptr among threads in the same row - using PackGQA_t = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; - Tensor tOrO = make_tensor(make_shape(Shape<_1, Int>{}, size<1>(tOcO), size<2>(tOcO))); - cute::clear(tOrO); - PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); - } - } - - } - -}; - -} // namespace flash diff --git a/flash-attn/flash.h b/flash-attn/flash.h deleted file mode 100644 index 28997613dc6c63a122fd21f8b4f9ce05a4fa771b..0000000000000000000000000000000000000000 --- a/flash-attn/flash.h +++ /dev/null @@ -1,220 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Qkv_params { - using index_t = int64_t; - // The QKV matrices. - void *__restrict__ q_ptr; - void *__restrict__ k_ptr; - void *__restrict__ v_ptr; - - // The stride between rows of the Q, K and V matrices. - index_t q_batch_stride; - index_t k_batch_stride; - index_t v_batch_stride; - index_t q_row_stride; - index_t k_row_stride; - index_t v_row_stride; - index_t q_head_stride; - index_t k_head_stride; - index_t v_head_stride; - index_t v_dim_stride; - - // The number of heads. - int h, h_k; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Flash_fwd_params : public Qkv_params { - using index_t = int64_t; - - // The O matrix (output). - void * __restrict__ o_ptr; - void * __restrict__ oaccum_ptr; - - // The stride between rows of O. - index_t o_batch_stride; - index_t o_row_stride; - index_t o_head_stride; - - // The pointer to the softmax sum. - void * __restrict__ softmax_lse_ptr; - void * __restrict__ softmax_lseaccum_ptr; - - // For FP8 scaling - float * __restrict__ q_descale_ptr; - float * __restrict__ k_descale_ptr; - float * __restrict__ v_descale_ptr; - index_t q_descale_batch_stride; - index_t q_descale_head_stride; - index_t k_descale_batch_stride; - index_t k_descale_head_stride; - index_t v_descale_batch_stride; - index_t v_descale_head_stride; - - // The dimensions. - int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; - int total_q, total_k, total_knew; - int b_k; // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q - int dv, dv_rounded; // For the case where V headdim is different from Q/K headdim - - // The scaling factors for the kernel. - float scale_softmax; - float softcap; - - // array of length b+1 holding starting offset of each sequence. - int * __restrict__ cu_seqlens_q; - int * __restrict__ cu_seqlens_k; - int * __restrict__ cu_seqlens_knew; - int * __restrict__ leftpad_k; - - // If provided, the actual length of each q/k sequence. - int *__restrict__ seqused_q; - int *__restrict__ seqused_k; - - // The stride between rows of Oaccum. - index_t oaccum_split_stride; - index_t oaccum_batch_stride; - index_t oaccum_row_stride; - index_t oaccum_head_stride; - - // The stride between rows of LSEaccum. - index_t lseaccum_split_stride; - index_t lseaccum_batch_stride; - index_t lseaccum_head_stride; - - // The K_new and V_new matrices. - void * __restrict__ knew_ptr; - void * __restrict__ vnew_ptr; - - // The stride between rows of the Q, K and V matrices. - index_t knew_batch_stride; - index_t vnew_batch_stride; - index_t knew_row_stride; - index_t vnew_row_stride; - index_t knew_head_stride; - index_t vnew_head_stride; - - void *__restrict__ qv_ptr; - index_t qv_batch_stride; - index_t qv_row_stride; - index_t qv_head_stride; - - // The cos and sin matrices for rotary embedding. - void * __restrict__ rotary_cos_ptr; - void * __restrict__ rotary_sin_ptr; - int *__restrict__ seqlens_rotary; - - // The indices to index into the KV cache. - int * __restrict__ kv_batch_idx; - - // Paged KV cache - int * __restrict__ page_table; - index_t page_table_batch_stride; - int page_size; - int num_pages; - bool pagedkv_tma; - - // The dropout probability (probability of keeping an activation). - float p_dropout; - // uint32_t p_dropout_in_uint; - // uint16_t p_dropout_in_uint16_t; - uint8_t p_dropout_in_uint8_t; - - // Scale factor of 1 / (1 - p_dropout). - float rp_dropout; - - // Local window size - int window_size_left, window_size_right; - - // Pointer to the RNG seed (idx 0) and offset (idx 1). - uint64_t * rng_state; - - bool is_bf16; - bool is_fp32; - bool is_e4m3; - bool is_causal; - bool is_local; - - bool is_rotary_interleaved; - - int num_splits; // For split-KV version - bool pack_gqa; - - int * __restrict__ tile_count_semaphore; - // int * __restrict__ num_m_blocks_ptr; - // int * __restrict__ num_n_blocks_ptr; - int * __restrict__ num_splits_dynamic_ptr; - bool skip_scheduler_metadata_computation; - - int arch; - int num_sm; - - // The S extra matrix, (num_heads) - void *__restrict__ s_aux_ptr; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Flash_bwd_params : public Flash_fwd_params { - using index_t = int64_t; - - // The dO and dQKV matrices. - void *__restrict__ do_ptr; - void *__restrict__ dq_ptr; - void *__restrict__ dk_ptr; - void *__restrict__ dv_ptr; - - // To accumulate dQ - void *__restrict__ dq_accum_ptr; - void *__restrict__ dk_accum_ptr; - void *__restrict__ dv_accum_ptr; - - // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q - // dimension void *__restrict__ dk_accum_ptr; void *__restrict__ - // dv_accum_ptr; - - // The stride between rows of the dO, dQ, dK and dV matrices. - index_t do_batch_stride; - index_t do_row_stride; - index_t do_head_stride; - index_t dq_batch_stride; - index_t dk_batch_stride; - index_t dv_batch_stride; - index_t dq_row_stride; - index_t dk_row_stride; - index_t dv_row_stride; - index_t dq_head_stride; - index_t dk_head_stride; - index_t dv_head_stride; - - // The pointer to the softmax d sum. - void *__restrict__ dsoftmax_sum; - void *__restrict__ softmax_lse_log2_ptr; - - int *__restrict__ dq_semaphore; - int *__restrict__ dk_semaphore; - int *__restrict__ dv_semaphore; - - bool deterministic; - index_t dq_accum_split_stride; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); -void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl); -template -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); -template -void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); diff --git a/flash-attn/flash_api.cpp b/flash-attn/flash_api.cpp deleted file mode 100644 index 07878dea63e0fa63c67745b7ab9e0a37ae30af3b..0000000000000000000000000000000000000000 --- a/flash-attn/flash_api.cpp +++ /dev/null @@ -1,1623 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers. -#include -#include // For TORCH_VERSION* macros -#include -#include - -#include - -#include "flash.h" -#include "static_switch.h" -#include "tile_size.h" -#include "heuristics.h" -#include "cuda_check.h" - -// Copied from https://github.com/pytorch/pytorch/commit/7931eee5c5ebcdf468bff4d308510b03355cd909 -// This is so that we can pass in torch.dtype as a parameter to the function. -#if TORCH_VERSION_MAJOR < 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR < 4) - -#include -#include - -namespace pybind11::detail { - - template <> - struct type_caster { - public: - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - PYBIND11_TYPE_CASTER(at::ScalarType, _("torch.dtype")); - // PYBIND11_TYPE_CASTER defines a member field called value. at::ScalarType - // cannot be default-initialized, we provide this constructor to explicitly - // initialize that field. The value doesn't matter as it will be overwritten - // after a successful call to load. - type_caster() : value(at::kFloat) {} - bool load(handle src, bool) { - PyObject* obj = src.ptr(); - if (THPDtype_Check(obj)) { - value = reinterpret_cast(obj)->scalar_type; - return true; - } - return false; - } - static handle cast( - const at::ScalarType& src, - return_value_policy /* policy */, - handle /* parent */) { - return Py_NewRef(torch::getTHPDtype(src)); - } - }; - -} // namespace pybind11::detail - -#endif - -#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") -#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") - -void set_params_fprop(Flash_fwd_params ¶ms, - // sizes - const size_t b, - const size_t seqlen_q, - const size_t seqlen_k, - const size_t seqlen_q_rounded, - const size_t seqlen_k_rounded, - const size_t h, - const size_t h_k, - const size_t d, - const size_t d_rounded, - // device pointers - const at::Tensor q, - const at::Tensor k, - const at::Tensor v, - at::Tensor out, - void *cu_seqlens_q_d, - void *cu_seqlens_k_d, - void *seqused_q, - void *seqused_k, - void *softmax_lse_d, - float p_dropout, - float softmax_scale, - int window_size_left, - int window_size_right, - const float softcap=0.f, - const int sm_margin=0) { - - // Reset the parameters - params = {}; - - params.is_bf16 = q.dtype() == torch::kBFloat16; - params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn; - - // Set the pointers and strides. - params.q_ptr = q.data_ptr(); - params.k_ptr = k.data_ptr(); - params.v_ptr = v.data_ptr(); - // All stride are in elements, not bytes. - params.q_row_stride = q.stride(-3); - params.k_row_stride = k.stride(-3); - params.v_row_stride = v.stride(-3); - params.q_head_stride = q.stride(-2); - params.k_head_stride = k.stride(-2); - params.v_head_stride = v.stride(-2); - params.v_dim_stride = v.stride(-1); - params.o_ptr = out.data_ptr(); - params.o_row_stride = out.stride(-3); - params.o_head_stride = out.stride(-2); - - if (cu_seqlens_q_d == nullptr) { - params.q_batch_stride = q.stride(0); - params.o_batch_stride = out.stride(0); - } - if (cu_seqlens_k_d == nullptr) { - params.k_batch_stride = k.stride(0); - params.v_batch_stride = v.stride(0); - } - - params.cu_seqlens_q = static_cast(cu_seqlens_q_d); - params.cu_seqlens_k = static_cast(cu_seqlens_k_d); - params.seqused_q = static_cast(seqused_q); - params.seqused_k = static_cast(seqused_k); - - // Softmax sum - params.softmax_lse_ptr = softmax_lse_d; - - // Set the dimensions. - params.b = b; - params.h = h; - params.h_k = h_k; - params.seqlen_q = seqlen_q; - params.seqlen_k = seqlen_k; - params.seqlen_q_rounded = seqlen_q_rounded; - params.seqlen_k_rounded = seqlen_k_rounded; - params.d = d; - params.d_rounded = d_rounded; - - // Set the different scale values. - params.scale_softmax = softmax_scale; - params.softcap = softcap; - - // Set this to probability of keeping an element to simplify things. - params.p_dropout = 1.f - p_dropout; - // Convert p from float to int so we don't have to convert the random uint to float to compare. - // [Minor] We want to round down since when we do the comparison we use <= instead of < - // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); - // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); - params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); - params.rp_dropout = 1.f / params.p_dropout; - TORCH_CHECK(p_dropout < 1.f); - #ifdef FLASHATTENTION_DISABLE_DROPOUT - TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); - #endif - - // Causal is the special case where window_size_right == 0 and window_size_left < 0. - // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. - params.is_causal = window_size_left < 0 && window_size_right == 0; - params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal; - - // TODO: check this - if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k - 1; } - if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_q - 1; } - params.window_size_left = window_size_left; - params.window_size_right = window_size_right; - - params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; - params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin; - - #ifdef FLASHATTENTION_DISABLE_LOCAL - TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); - #endif -} - -void set_params_dgrad(Flash_bwd_params ¶ms, - // sizes - const size_t b, - const size_t seqlen_q, - const size_t seqlen_k, - const size_t seqlen_q_rounded, - const size_t seqlen_k_rounded, - const size_t h, - const size_t h_k, - const size_t d, - const size_t d_rounded, - // device pointers - const at::Tensor q, - const at::Tensor k, - const at::Tensor v, - const at::Tensor out, - const at::Tensor dout, - at::Tensor dq, - at::Tensor dk, - at::Tensor dv, - void *cu_seqlens_q_d, - void *cu_seqlens_k_d, - void *seqused_q, - void *seqused_k, - void *dq_accum_d, - void *dk_accum_d, - void *dv_accum_d, - void *softmax_lse_d, - void *dsoftmax_sum_d, - float p_dropout, - float softmax_scale, - int window_size_left, - int window_size_right, - const float softcap=0.f, - bool deterministic=false, - int const sm_margin=0) { - - set_params_fprop(params, - b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, - q, k, v, out, - cu_seqlens_q_d, - cu_seqlens_k_d, - seqused_q, - seqused_k, - softmax_lse_d, - p_dropout, - softmax_scale, - window_size_left, - window_size_right, - softcap, - sm_margin); - - // Set the pointers and strides. - params.do_ptr = dout.data_ptr(); - params.do_row_stride = dout.stride(-3); - params.do_head_stride = dout.stride(-2); - params.dq_ptr = dq.data_ptr(); - params.dk_ptr = dk.data_ptr(); - params.dv_ptr = dv.data_ptr(); - params.dq_row_stride = dq.stride(-3); - params.dk_row_stride = dk.stride(-3); - params.dv_row_stride = dv.stride(-3); - params.dq_head_stride = dq.stride(-2); - params.dk_head_stride = dk.stride(-2); - params.dv_head_stride = dv.stride(-2); - - if (cu_seqlens_q_d == nullptr) { - params.do_batch_stride = dout.stride(0); - params.dq_batch_stride = dq.stride(0); - params.dk_batch_stride = dk.stride(0); - params.dv_batch_stride = dv.stride(0); - } - - params.dq_accum_ptr = dq_accum_d; - params.dk_accum_ptr = dk_accum_d; - params.dv_accum_ptr = dv_accum_d; - - // Softmax sum - params.dsoftmax_sum = dsoftmax_sum_d; - - params.deterministic = deterministic; -} - -void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - // HEADDIM_SWITCH(params.d, [&] { - // run_mha_fwd_(params, stream); - // }); - TORCH_CHECK(params.num_splits >= 1); - ARCH_SWITCH(params.arch, Arch, [&] { - SPLIT_SWITCH(params.num_splits > 1, Split, [&] { - PAGEDKV_SWITCH(params.page_table && !params.pagedkv_tma, PagedKVNonTMA, [&] { - PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] { - // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation - static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKVNonTMA || Split; - SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] { - if (!params.is_e4m3) { - if (params.is_bf16) { - #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { - if (params.dv > 256 && Arch == 90) { - return run_mha_fwd_(params, stream); - } else if (params.dv > 64 && Arch == 90) { - return run_mha_fwd_(params, stream); - } else { - return run_mha_fwd_(params, stream); - } - } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { - if (params.dv <= 128 && Arch == 90) { - return run_mha_fwd_(params, stream); - } else { - return run_mha_fwd_(params, stream); - } - } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_(params, stream); } - #endif - } else { - #ifndef FLASHATTENTION_DISABLE_FP16 - #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { - if (params.dv > 256 && Arch == 90) { - return run_mha_fwd_(params, stream); - } else if (params.dv > 64 && Arch == 90) { - return run_mha_fwd_(params, stream); - } else { - return run_mha_fwd_(params, stream); - } - } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { - if (params.dv <= 128 && Arch == 90) { - return run_mha_fwd_(params, stream); - } else { - return run_mha_fwd_(params, stream); - } - } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_(params, stream); } - #endif - #else - TORCH_CHECK(false, "This flash attention build does not support FP16."); - #endif - } - } else { - #ifndef FLASHATTENTION_DISABLE_FP8 - #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { - if (params.dv <= 128 && Arch == 90) { - return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); - } else { - return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); - } - } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } - #endif - #else - TORCH_CHECK(false, "This flash attention build does not support FP8."); - #endif - } - }); - }); - }); - }); - }); -} - -void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl=false) { - #ifndef FLASHATTENTION_DISABLE_SPLIT - // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively - // so that kBlockM is smaller and we have more parallelism. - if (params.is_fp32) { - if (params.dv <= 64) { - run_mha_fwd_combine_(params, stream, enable_pdl); - } else { - run_mha_fwd_combine_(params, stream, enable_pdl); - } - } else if (params.is_bf16) { - if (params.dv <= 64) { - run_mha_fwd_combine_(params, stream, enable_pdl); - } else { - run_mha_fwd_combine_(params, stream, enable_pdl); - } - } else { - if (params.dv <= 64) { - run_mha_fwd_combine_(params, stream, enable_pdl); - } else { - run_mha_fwd_combine_(params, stream, enable_pdl); - } - } - #else - TORCH_CHECK(false, "This flash attention build does not support combine kernels."); - #endif -} - -inline bool get_pagedkv_tma(Flash_fwd_params const& params) { - if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; } - // This needs to match the kernel configs - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f); - int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); - int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90); - // Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower, - // at least for MLA. - return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM; -} - -inline bool get_pack_gqa(Flash_fwd_params const& params) { - // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size. - // Has little effect on speed. - if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; } - #ifdef FLASHATTENTION_DISABLE_PACKGQA - return false; - #else - // params.page_table must already be set - if (params.h == params.h_k) { return false; } - // This needs to match the kernel configs - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); - int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); - return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM); - #endif -} - -inline int get_num_splits(Flash_fwd_params const& params) { - #ifdef FLASHATTENTION_DISABLE_SPLIT - return 1; - #else - // Always enable PackGQA for Split - // params.page_table must already be set - // This needs to match the kernel configs - bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k; - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg(params)); - // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits - // has not been set here. It's OK though because we might just underestimate kBlockN a bit - auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr); - int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); - int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); - int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k); - // If is_local, we're not going to load all of seqlen_k - int const seqlen_k_loaded = !params.is_local - ? params.seqlen_k - : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM)); - int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN; - int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM; - int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2); - // Always enable PackGQA for Split - // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits. - // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending - // that batch = 1. - int total_mblocks = (params.num_splits_dynamic_ptr ? 1 : params.b) * params.h_k * num_m_blocks; - return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128); - #endif -} - -inline int get_max_headdim() { - #ifndef FLASHATTENTION_DISABLE_HDIM256 - return 256; - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM192 - return 192; - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM128 - return 128; - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM96 - return 96; - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM64 - return 64; - #endif - return 0; -} - -inline int round_up_headdim(int head_size) { - #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (head_size <= 64) { return 64; } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (head_size <= 96) { return 96; } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (head_size <= 128) { return 128; } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (head_size <= 192) { return 192; } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (head_size <= 256) { return 256; } - #endif - return 256; -} - -inline int round_up_headdimv(int head_size) { - if (head_size <= 64) { return 64; } - if (head_size <= 96) { return 96; } - if (head_size <= 128) { return 128; } - if (head_size <= 192) { return 192; } - if (head_size <= 256) { return 256; } - return 512; -} - -// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available -at::Tensor -mha_fwd_get_scheduler_metadata( - int batch_size, - int max_seqlen_q, - int max_seqlen_k, - int num_heads, - int num_heads_k, - int headdim, - int headdim_v, - at::ScalarType qkv_dtype, - const at::Tensor &seqused_k, // b - std::optional &cu_seqlens_q_, // b+1 - std::optional &cu_seqlens_k_, // b+1 - std::optional &cu_seqlens_k_new_, // b+1 - std::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. - std::optional &leftpad_k_, // b - std::optional page_size, - int max_seqlen_k_new, // 0 means we're not appending new KV - bool is_causal, - int window_size_left, - int window_size_right, - bool has_softcap, - int num_splits, - std::optional pack_gqa_, - int const sm_margin - ) { - - TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn, - "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - - // Reset the parameters - Flash_fwd_params params{}; - params.is_bf16 = qkv_dtype == at::ScalarType::BFloat16; - params.is_e4m3 = qkv_dtype == at::ScalarType::Float8_e4m3fn; - params.b = batch_size; - params.seqlen_q = max_seqlen_q; - params.seqlen_k = max_seqlen_k; - params.h = num_heads; - params.h_k = num_heads_k; - params.d = headdim; - params.dv = headdim_v; - params.d_rounded = round_up_headdim(headdim); - params.dv_rounded = headdim_v == headdim ? params.d_rounded : round_up_headdimv(headdim_v); - params.seqlen_knew = max_seqlen_k_new; - - bool const is_varlen_q = cu_seqlens_q_.has_value(); - params.cu_seqlens_q = is_varlen_q ? cu_seqlens_q_.value().data_ptr() : nullptr; - bool const is_varlen_k = cu_seqlens_k_.has_value(); - params.cu_seqlens_k = is_varlen_k ? cu_seqlens_k_.value().data_ptr() : nullptr; - params.cu_seqlens_knew = cu_seqlens_k_new_.has_value() ? cu_seqlens_k_new_.value().data_ptr() : nullptr; - params.seqused_q = seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr; - params.seqused_k = seqused_k.data_ptr(); - params.leftpad_k = leftpad_k_.has_value() ? leftpad_k_.value().data_ptr() : nullptr; - params.knew_ptr = params.seqlen_knew > 0 ? reinterpret_cast(1) : nullptr; - if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; } - if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; } - // causal=true is the same as causal=false in this case - if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1) { - // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA - if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) { - is_causal = false; - } - } - if (is_causal) { window_size_right = 0; } - - params.is_causal = window_size_left < 0 && window_size_right == 0; - params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal; - if (window_size_left < 0 && window_size_right >= 0) { window_size_left = max_seqlen_k - 1; } - if (window_size_left >= 0 && window_size_right < 0) { window_size_right = max_seqlen_q - 1; } - params.window_size_left = window_size_left; - params.window_size_right = window_size_right; - params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; - params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin; - params.softcap = has_softcap ? 1.0f : 0.0f; - - params.page_size = page_size.has_value() ? page_size.value() : 1; - params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast(1); - - bool const use_dynamic_split = params.b <= 992; - params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); - - params.pagedkv_tma = get_pagedkv_tma(params); - // Determine if we should pack GQA before num_splits since it impacts use_one_mma_wg (in get_num_splits) - params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); - params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; - // Always enable PackGQA for Split - params.pack_gqa = params.num_splits > 1; - - bool is_varlen = true; - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)seqused_k.get_device()}; - - auto opts = seqused_k.options(); - // This needs to be set after get_num_splits - at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic - bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1; - if (scheduler_needs_semaphore || use_dynamic_split) { - tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::kInt32)); - if (scheduler_needs_semaphore) { - if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing - params.tile_count_semaphore = tile_count_semaphore.data_ptr(); - } else { - params.tile_count_semaphore = nullptr; - } - params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 : nullptr; - } - - if (params.num_splits_dynamic_ptr) { - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg(params)); - auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr); - int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); - int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); - auto stream = at::cuda::getCurrentCUDAStream().stream(); - prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/); - CHECK_CUDA_KERNEL_LAUNCH(); - } - return tile_count_semaphore; -} - -// b: batch_size -// b_k: batch_size_k -// s_q: seqlen_q -// s_k: seqlen_k -// s_k_new: seqlen_k_new -// h: num_heads -// h_k: num_heads_k -// d: head_size -std::vector -mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor &k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. - const at::Tensor &v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. - std::optional &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new - std::optional &v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new - std::optional &q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q - std::optional &out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q - std::optional &cu_seqlens_q_, // b+1 - std::optional &cu_seqlens_k_, // b+1 - std::optional &cu_seqlens_k_new_, // b+1 - std::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. - std::optional &seqused_k_, // b. If given, only this many elements of each batch element's keys are used. - std::optional max_seqlen_q_, - // TODO: check if we need max_seqlen_k - std::optional max_seqlen_k_, - std::optional &page_table_, // (b_k, max_num_pages_per_seq) - std::optional &kv_batch_idx_, // b. indices to index into the KV cache - std::optional &leftpad_k_, // b - std::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) - std::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) - std::optional &seqlens_rotary_, // b - std::optional &q_descale_, // (b, h_k), not (b, h) - std::optional &k_descale_, // (b, h_k) - std::optional &v_descale_, // (b, h_k) - float const softmax_scale, - bool is_causal, - int window_size_left, - int window_size_right, - float const softcap, - bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 - std::optional &scheduler_metadata_, // (b + 1) - int num_splits, - std::optional pack_gqa_, - int const sm_margin, - std::optional &s_aux_ // (h) - ) { - - auto dprops = at::cuda::getCurrentDeviceProperties(); - bool is_sm8x = dprops->major >= 8; - TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); - - auto q_type = q.scalar_type(); - TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn, - "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); - if (dprops->major < 9) { - TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, - "FlashAttention on Ampere/Ada cards only supports fp16 and bf16 data type"); - } - TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); - TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); - - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); - - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - - at::Tensor page_table; - const bool paged_KV = page_table_.has_value(); - if (paged_KV) { - page_table = page_table_.value(); - CHECK_DEVICE(page_table); - TORCH_CHECK(page_table.dtype() == torch::kInt32, "page_table must have dtype torch.int32"); - TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension"); - } - - at::Tensor cu_seqlens_q; - bool const is_varlen_q = cu_seqlens_q_.has_value(); - if (is_varlen_q) { - cu_seqlens_q = cu_seqlens_q_.value(); - CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q); - TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32"); - TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided"); - } - at::Tensor cu_seqlens_k; - bool const is_varlen_k = cu_seqlens_k_.has_value(); - if (is_varlen_k) { - cu_seqlens_k = cu_seqlens_k_.value(); - CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k); - TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32"); - TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided"); - TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported"); - TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then page table is not supported"); - } - - auto const sizes = q.sizes(); - const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1; - int seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value(); - int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0]; - int num_heads = q.size(-2); - int const head_size = q.size(-1); - int const head_size_v = v.size(-1); - int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1); - int const num_pages = !paged_KV ? 0 : k.size(0); - int const page_size = !paged_KV ? 1 : k.size(1); - int const seqlen_k = !max_seqlen_k_.has_value() ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value(); - int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0); - int const num_heads_k = k.size(-2); - int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0); - if (!kv_batch_idx_.has_value()) { - TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k"); - } - int const max_headdim = get_max_headdim(); - TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - if (head_size_v != head_size) { - TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) || - (head_size <= 64 && head_size_v <= 512), - "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], " - "or (Q/K <= 64 and V <= 512)."); - TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim"); - if (head_size_v > 256) { - TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, - "HeaddimV > 256 requires fp16 and bf16 data type"); - } - } - - // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM - // TODO: check this - if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } - if (window_size_right >= seqlen_q - 1) { window_size_right = -1; } - // causal=true is the same as causal=false in this case - if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1) { - // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA - if ((head_size <= 64 || head_size > 128) || !paged_KV) { - is_causal = false; - } - } - if (is_causal) { window_size_right = 0; } - // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_fprop will set params.is_causal=true. - // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM. - is_causal = window_size_left < 0 && window_size_right == 0; - - if (!is_varlen_q) { - CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); - } else { - CHECK_SHAPE(q, total_q, num_heads, head_size); - CHECK_SHAPE(cu_seqlens_q, batch_size + 1); - } - if (!paged_KV) { - if (!is_varlen_k) { - CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v); - } else { - CHECK_SHAPE(k, total_k, num_heads_k, head_size); - CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); - CHECK_SHAPE(cu_seqlens_k, batch_size + 1); - } - } else { - CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); - CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); - CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq); - } - - if (seqused_q_.has_value()){ - auto seqused_q = seqused_q_.value(); - TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32"); - CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q); - CHECK_SHAPE(seqused_q, batch_size); - } - if (seqused_k_.has_value()) { - auto seqused_k = seqused_k_.value(); - TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32"); - CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k); - CHECK_SHAPE(seqused_k, batch_size); - } - - if (leftpad_k_.has_value()) { - auto leftpad_k = leftpad_k_.value(); - TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); - CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k); - CHECK_SHAPE(leftpad_k, batch_size); - } - - // This is what we will template on - bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value(); - #ifdef FLASHATTENTION_DISABLE_VARLEN - TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); - #endif - - int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8; - TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); - TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); - - auto opts = q.options(); - auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type; - at::Tensor out; - if (out_.has_value()) { - out = out_.value(); - TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16"); - CHECK_DEVICE(out); - TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); - if (!is_varlen_q) { - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); - } else { - CHECK_SHAPE(out, total_q, num_heads, head_size_v); - } - } else { - out = !is_varlen_q - ? torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type)) - : torch::empty({total_q, num_heads, head_size_v}, opts.dtype(out_type)); - } - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - int const head_size_rounded = round_up_headdim(head_size); - int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdimv(head_size_v); - int const seqlen_q_rounded = round_multiple(seqlen_q, 128); - int const seqlen_k_rounded = round_multiple(seqlen_k, 128); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)q.get_device()}; - - at::Tensor softmax_lse; - if (!is_varlen_q) { - softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); - } else { - softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); - } - - Flash_fwd_params params; - set_params_fprop(params, - batch_size, - seqlen_q, seqlen_k, - seqlen_q_rounded, seqlen_k_rounded, - num_heads, num_heads_k, - head_size, head_size_rounded, - q, k, v, out, - !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(), - !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(), - seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr, - seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr, - softmax_lse.data_ptr(), - /*p_dropout=*/0.f, - softmax_scale, - window_size_left, - window_size_right, - softcap, - sm_margin); - params.total_q = total_q; - params.total_k = total_k; - params.b_k = batch_size_k; - params.dv = head_size_v; - params.dv_rounded = head_size_v_rounded; - if (leftpad_k_.has_value()) { // This needs to be set before get_pagedkv_tma - params.leftpad_k = static_cast(leftpad_k_.value().data_ptr()); - } - if (paged_KV) { - params.page_table = page_table.data_ptr(); - params.page_table_batch_stride = page_table.stride(0); - } - params.page_size = page_size; - params.num_pages = num_pages; - - if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma - at::Tensor k_new, v_new; - TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in"); - TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqlens_k must also be passed in"); - TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache"); - at::Tensor cu_seqlens_k_new; - bool const is_varlen_k_new = cu_seqlens_k_new_.has_value(); - if (is_varlen_k_new) { - cu_seqlens_k_new = cu_seqlens_k_new_.value(); - CHECK_DEVICE(cu_seqlens_k_new); CHECK_CONTIGUOUS(cu_seqlens_k_new); - TORCH_CHECK(cu_seqlens_k_new.dtype() == torch::kInt32, "cu_seqlens_k_new must have dtype torch.int32"); - } - k_new = k_new_.value(); - v_new = v_new_.value(); - TORCH_CHECK(k_new.dtype() == q_type, "k_new must have the same dtype as query"); - TORCH_CHECK(v_new.dtype() == q_type, "v_new must have the same dtype as query"); - CHECK_DEVICE(k_new); CHECK_DEVICE(v_new); - TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension"); - TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension"); - // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new - int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0; - int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1): k_new.size(0); - if (!is_varlen_k_new) { - CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size); - CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v); - } else { - CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size); - CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v); - CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1); - } - params.seqlen_knew = seqlen_k_new; - params.total_knew = total_k_new; - params.knew_ptr = k_new.data_ptr(); - params.vnew_ptr = v_new.data_ptr(); - // All stride are in elements, not bytes. - params.knew_row_stride = k_new.stride(-3); - params.vnew_row_stride = v_new.stride(-3); - params.knew_head_stride = k_new.stride(-2); - params.vnew_head_stride = v_new.stride(-2); - if (!is_varlen_k_new) { - params.knew_batch_stride = k_new.stride(0); - params.vnew_batch_stride = v_new.stride(0); - } - if (is_varlen_k_new) { - params.cu_seqlens_knew = static_cast(cu_seqlens_k_new.data_ptr()); - } - } - - // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel - bool const use_dynamic_split = is_varlen && params.b <= 992; - // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it - params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); - - params.pagedkv_tma = get_pagedkv_tma(params); - // Determine if we should pack GQA before num_splits since it impacts use_one_mma_wg (in get_num_splits) - params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); - params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; - // Always enable PackGQA for Split - params.pack_gqa = params.num_splits > 1; - - // This needs to be set after get_num_splits - at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic - // We don't use the persistent scheduler if Split and not Varlen - bool const scheduler_needs_semaphore = params.arch >= 90 - ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) - : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); - if (scheduler_needs_semaphore || use_dynamic_split) { - int metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b; - params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value(); - if (scheduler_metadata_.has_value()) { - at::Tensor scheduler_metadata = scheduler_metadata_.value(); - CHECK_DEVICE(scheduler_metadata); - CHECK_SHAPE(scheduler_metadata, metadata_size); - CHECK_CONTIGUOUS(scheduler_metadata); - TORCH_CHECK(scheduler_metadata.dtype() == torch::kInt32, "scheduler_metadata must have dtype int32"); - tile_count_semaphore = scheduler_metadata; - } else { - tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32)); - } - if (scheduler_needs_semaphore && !use_dynamic_split) { - tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing - } - params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr() : nullptr; - params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 : nullptr; - } - - if (q_v_.has_value()) { - TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); - TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, - "q_v is only supported for fp16 and bf16 data type"); - TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); - at::Tensor q_v = q_v_.value(); - TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query"); - CHECK_DEVICE(q_v); - TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); - if (!is_varlen_q) { - CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v); - } else { - CHECK_SHAPE(q_v, total_q, num_heads, head_size_v); - } - params.qv_ptr = q_v.data_ptr(); - // All stride are in elements, not bytes. - params.qv_row_stride = q_v.stride(-3); - params.qv_head_stride = q_v.stride(-2); - if (!is_varlen_q) { - params.qv_batch_stride = q_v.stride(0); - } - } - - if (rotary_cos_.has_value()) { - TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided"); - auto rotary_cos = rotary_cos_.value(); - CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos); - params.rotary_dim = rotary_cos.size(1) * 2; - TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); - TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); - const int seqlen_ro = rotary_cos.size(0); - if (paged_KV) { - TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); - } - CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); - TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); - - TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); - auto rotary_sin = rotary_sin_.value(); - CHECK_DEVICE(rotary_sin); CHECK_CONTIGUOUS(rotary_sin); - CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); - TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); - params.rotary_cos_ptr = rotary_cos.data_ptr(); - params.rotary_sin_ptr = rotary_sin.data_ptr(); - params.is_rotary_interleaved = is_rotary_interleaved; - if (seqlens_rotary_.has_value()) { - at::Tensor seqlens_rotary = seqlens_rotary_.value(); - CHECK_DEVICE(seqlens_rotary); CHECK_CONTIGUOUS(seqlens_rotary); - TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32"); - CHECK_SHAPE(seqlens_rotary, batch_size); - params.seqlens_rotary = seqlens_rotary.data_ptr(); - } - } else { - params.rotary_dim = 0; - } - - if (kv_batch_idx_.has_value()) { - auto kv_batch_idx = kv_batch_idx_.value(); - CHECK_DEVICE(kv_batch_idx); CHECK_CONTIGUOUS(kv_batch_idx); - TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32"); - params.kv_batch_idx = reinterpret_cast(kv_batch_idx.data_ptr()); - } - - at::Tensor out_accum, softmax_lse_accum; - auto outaccum_type = at::ScalarType::Float; - if (params.num_splits > 1) { - TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported"); - if (!is_varlen_q) { - out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(outaccum_type)); - softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); - params.oaccum_batch_stride = out_accum.stride(1); - params.lseaccum_batch_stride = softmax_lse_accum.stride(1); - } else { - out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size_v}, opts.dtype(outaccum_type)); - softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat)); - } - params.is_fp32 = false; - params.oaccum_ptr = out_accum.data_ptr(); - params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); - params.oaccum_split_stride = out_accum.stride(0); - params.oaccum_row_stride = out_accum.stride(-2); - params.oaccum_head_stride = out_accum.stride(-3); - params.lseaccum_split_stride = softmax_lse_accum.stride(0); - params.lseaccum_head_stride = softmax_lse_accum.stride(-2); - } - - if (q_type == at::ScalarType::Float8_e4m3fn) { - if (q_descale_.has_value()) { - auto q_descale = q_descale_.value(); - CHECK_DEVICE(q_descale); - CHECK_SHAPE(q_descale, batch_size, num_heads_k); - params.q_descale_ptr = q_descale.data_ptr(); - params.q_descale_batch_stride = q_descale.stride(0); - params.q_descale_head_stride = q_descale.stride(1); - } else { - params.q_descale_ptr = nullptr; - } - if (k_descale_.has_value()) { - auto k_descale = k_descale_.value(); - CHECK_DEVICE(k_descale); - CHECK_SHAPE(k_descale, batch_size, num_heads_k); - params.k_descale_ptr = k_descale.data_ptr(); - params.k_descale_batch_stride = k_descale.stride(0); - params.k_descale_head_stride = k_descale.stride(1); - } else { - params.k_descale_ptr = nullptr; - } - if (v_descale_.has_value()) { - auto v_descale = v_descale_.value(); - CHECK_DEVICE(v_descale); - CHECK_SHAPE(v_descale, batch_size, num_heads_k); - params.v_descale_ptr = v_descale.data_ptr(); - params.v_descale_batch_stride = v_descale.stride(0); - params.v_descale_head_stride = v_descale.stride(1); - } else { - params.v_descale_ptr = nullptr; - } - } - - if(s_aux_.has_value()) { - auto s_aux = s_aux_.value(); - TORCH_CHECK(s_aux.scalar_type() == at::ScalarType::BFloat16, - "We only support bf16 dtype for S extra."); - CHECK_DEVICE(s_aux); - CHECK_SHAPE(s_aux, num_heads); - CHECK_CONTIGUOUS(s_aux); - params.s_aux_ptr = s_aux.data_ptr(); - } else { - params.s_aux_ptr = nullptr; - } - - #ifdef FLASHATTENTION_DISABLE_LOCAL - TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); - #endif - #ifdef FLASHATTENTION_DISABLE_SOFTCAP - TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping."); - #endif - #ifdef FLASHATTENTION_DISABLE_SPLIT - TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits."); - #endif - #ifdef FLASHATTENTION_DISABLE_PACKGQA - TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, "This flash attention build does not support pack_gqa."); - #endif - #ifdef FLASHATTENTION_DISABLE_PAGEDKV - TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), "This flash attention build does not support paged KV."); - #endif - #ifdef FLASHATTENTION_DISABLE_APPENDKV - TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV."); - #endif - - if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); - run_mha_fwd(params, stream); - if (params.num_splits > 1) { - if (out_type == at::ScalarType::BFloat16) { - // Since we want output in BF16. Otherwise fwd_combine will output to FP16 - params.is_bf16 = true; - } - // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1 - // and seqlen = total_q, and don't need to dispatch to Varlen there. - // However, with dynamic split, each row needs to know which batch it belongs to - // to read the number of splits, so we just use the varlen version of combine kernel. - // if (is_varlen_q && !seqused_q_.has_value()) { - // if (is_varlen_q) { - // params.b = 1; - // params.seqlen_q = total_q; - // } - // This will zero out the semaphore if needed - run_mha_fwd_combine(params, stream, true /*enable_pdl*/); - } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) { - // need to zero out the semaphore in this case - tile_count_semaphore.index({torch::indexing::Slice(0, 1)}).zero_(); - } - } else if (total_q > 0 && num_heads_k > 0) { - // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. - out.zero_(); - softmax_lse.fill_(std::numeric_limits::infinity()); - } - - // return {out, softmax_lse}; - return {out, softmax_lse, out_accum, softmax_lse_accum}; -} - -void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { - #ifndef FLASHATTENTION_DISABLE_BACKWARD - // FP16_SWITCH(!params.is_bf16, [&] { - // HEADDIM_SWITCH(params.d, [&] { - // run_mha_bwd_(params, stream); - // }); - // }); - ARCH_SWITCH(params.arch, Arch, [&] { - SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] { - if (!params.is_bf16) { - #ifndef FLASHATTENTION_DISABLE_FP16 - #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_bwd_(params, stream); } - #endif - #else - TORCH_CHECK(false, "This flash attention build does not support FP16."); - #endif - } else { - #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_bwd_(params, stream); } - #endif - } - }); - }); - #endif -} - - -// b: batch_size -// s_q: seqlen_q -// s_k: seqlen_k -// h: num_heads -// h_k: num_heads_k -// d: head_size -std::vector mha_bwd( - const at::Tensor &dout, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor &k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k - const at::Tensor &v, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k - const at::Tensor &out, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor &softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q - std::optional &dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - std::optional &dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k - std::optional &dv_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k - std::optional &cu_seqlens_q_, // b+1 - std::optional &cu_seqlens_k_, // b+1 - std::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. - std::optional &seqused_k_, // b. If given, only this many elements of each batch element's keys are used. - std::optional max_seqlen_q_, - std::optional max_seqlen_k_, - float const softmax_scale, - bool is_causal, - int window_size_left, - int window_size_right, - float const softcap, - bool const deterministic, - int const sm_margin) { - - #ifdef FLASHATTENTION_DISABLE_BACKWARD - TORCH_CHECK(false, "This flash attention build does not support backward."); - #endif - - auto dprops = at::cuda::getCurrentDeviceProperties(); - bool is_sm8x = dprops->major >= 8; - TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); - - auto q_type = q.dtype(); - TORCH_CHECK(q_type == torch::kFloat16 || q_type == torch::kBFloat16, - "FlashAttention only support fp16 and bf16 data type"); - TORCH_CHECK(k.dtype() == q_type, "query and key must have the same dtype"); - TORCH_CHECK(v.dtype() == q_type, "query and value must have the same dtype"); - TORCH_CHECK(out.dtype() == q_type, "query and out must have the same dtype"); - TORCH_CHECK(dout.dtype() == q_type, "query and dout must have the same dtype"); - - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); - CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); - - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); - TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); - - at::Tensor cu_seqlens_q; - bool const is_varlen_q = cu_seqlens_q_.has_value(); - if (is_varlen_q) { - cu_seqlens_q = cu_seqlens_q_.value(); - CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q); - TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32"); - TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided"); - } - at::Tensor cu_seqlens_k; - bool const is_varlen_k = cu_seqlens_k_.has_value(); - if (is_varlen_k) { - cu_seqlens_k = cu_seqlens_k_.value(); - CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k); - TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32"); - TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided"); - } - // This is what we will template on - bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value(); - #ifdef FLASHATTENTION_DISABLE_VARLEN - TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); - #endif - - auto const sizes = q.sizes(); - int const batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1; - int const seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value(); - int const total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0]; - int const num_heads = q.size(-2); - int const head_size = q.size(-1); - int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value(); - int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0); - int const num_heads_k = k.size(-2); - TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); - int const max_headdim = get_max_headdim(); - TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - - // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM - if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } - if (window_size_right >= seqlen_q - 1) { window_size_right = -1; } - if (is_causal) { window_size_right = 0; } - // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_bprop will set params.is_causal=true. - // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA). - is_causal = window_size_left < 0 && window_size_right == 0; - - int const arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; - int const head_size_rounded = round_up_headdim(head_size); - // Very important that these match the kernel configs - bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal; - int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128) - : (head_size_rounded <= 96 ? 64 - : (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80) - : 64)); - int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64; - int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32; - int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80); - int const kBlockN_sm90 = head_size_rounded <= 128 - ? 128 - : (head_size_rounded <= 192 ? 96 : 80); - int const kBlockN_sm80 = head_size_rounded <= 128 - ? 128 - : (head_size_rounded <= 192 ? 80 : 64); - int const kBlockN_sm86 = head_size_rounded <= 64 ? 128 - : (head_size_rounded <= 96 ? 128 - : (head_size_rounded <= 128 ? 96 - : (head_size_rounded <= 192 ? 64 : 64))); - int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 89 ? kBlockN_sm86 : kBlockN_sm80); - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM); - int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN); - int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM); - int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN); - - if (!is_varlen_q) { - CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); - CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size); - } else { - CHECK_SHAPE(q, total_q, num_heads, head_size); - CHECK_SHAPE(out, total_q, num_heads, head_size); - CHECK_SHAPE(dout, total_q, num_heads, head_size); - CHECK_SHAPE(cu_seqlens_q, batch_size + 1); - } - if (!is_varlen_k) { - CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - } else { - CHECK_SHAPE(k, total_k, num_heads_k, head_size); - CHECK_SHAPE(v, total_k, num_heads_k, head_size); - CHECK_SHAPE(cu_seqlens_k, batch_size + 1); - } - - if (seqused_q_.has_value()){ - auto seqused_q = seqused_q_.value(); - TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32"); - CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q); - CHECK_SHAPE(seqused_q, batch_size); - } - if (seqused_k_.has_value()){ - auto seqused_k = seqused_k_.value(); - TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32"); - CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k); - CHECK_SHAPE(seqused_k, batch_size); - } - - at::Tensor dq, dk, dv; - if (dq_.has_value()) { - dq = dq_.value(); - TORCH_CHECK(dq.dtype() == q_type, "dq must have the same dtype as q"); - CHECK_DEVICE(dq); - TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); - if (!is_varlen_q) { - CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); - } else { - CHECK_SHAPE(dq, total_q, num_heads, head_size); - } - } else { - dq = torch::empty_like(q); - } - if (dk_.has_value()) { - dk = dk_.value(); - TORCH_CHECK(dk.dtype() == q_type, "dk must have the same dtype as q"); - CHECK_DEVICE(dk); - TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); - if (!is_varlen_k) { - CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size); - } else { - CHECK_SHAPE(dk, total_k, num_heads_k, head_size); - } - } else { - dk = torch::empty_like(k); - } - if (dv_.has_value()) { - dv = dv_.value(); - TORCH_CHECK(dv.dtype() == q_type, "dv must have the same dtype as q"); - CHECK_DEVICE(dv); - TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); - if (!is_varlen_k) { - CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size); - } else { - CHECK_SHAPE(dv, total_k, num_heads_k, head_size); - } - } else { - dv = torch::empty_like(v); - } - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)q.get_device()}; - - auto opts = q.options(); - // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64 - at::Tensor softmax_d, softmax_lse_log2; - if (!is_varlen) { - // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64 - softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); - softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); - } else { - softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat)); - softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat)); - } - at::Tensor dq_accum, dk_accum, dv_accum; - if (!is_varlen) { - dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, opts.dtype(at::kFloat)); - } else { - dq_accum = torch::empty({num_heads, total_q_padded_rounded * head_size_rounded}, opts.dtype(at::kFloat)); - } - if (num_heads_k != num_heads) { // MQA / GQA - if (!is_varlen) { - dk_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat)); - dv_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat)); - } else { - dk_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat)); - dv_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat)); - } - } - - Flash_bwd_params params; - set_params_dgrad(params, - batch_size, - seqlen_q, seqlen_k, - seqlen_q_rounded, seqlen_k_rounded, - num_heads, num_heads_k, - head_size, head_size_rounded, - q, k, v, out, - dout, dq, dk, dv, - !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(), - !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(), - seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr, - seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr, - dq_accum.data_ptr(), - num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr, - num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr, - softmax_lse.data_ptr(), - softmax_d.data_ptr(), - /*p_dropout=*/0.f, - softmax_scale, - window_size_left, - window_size_right, - softcap, - deterministic, - sm_margin); - params.total_q = total_q; - params.total_k = total_k; - params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr(); - params.dv = head_size; // We don't support hdim_v being different from hdim_qk for now - - // auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32)); - // params.tile_count_semaphore = tile_count_semaphore.data_ptr(); - // Will be zero'ed out in the backward preprocess kernel - at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32)); - params.dq_semaphore = dq_semaphore.data_ptr(); - if (num_heads_k != num_heads && params.deterministic) { - // TODO: do we need to zero them out? - at::Tensor dk_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); - at::Tensor dv_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); - params.dk_semaphore = dk_semaphore.data_ptr(); - params.dv_semaphore = dv_semaphore.data_ptr(); - } - - #ifdef FLASHATTENTION_DISABLE_LOCAL - TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); - #endif - #ifdef FLASHATTENTION_DISABLE_SOFTCAP - TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping."); - #endif - - if (total_q > 0 && total_k > 0 && num_heads_k > 0) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); - run_mha_bwd(params, stream); - } else if (total_k > 0 && num_heads_k > 0) { - // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. - dk.zero_(); - dv.zero_(); - softmax_d.zero_(); - } else if (total_q > 0 && num_heads_k > 0) { - dq.zero_(); - softmax_d.zero_(); - } - - return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum }; -} - -std::vector -mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x seqlen x num_heads x head_size - const at::Tensor &lse_partial, // num_splits x batch_size x seqlen x num_heads - std::optional out_, // batch_size x seqlen x num_heads x head_size - std::optional out_dtype_ - ) { - - auto dprops = at::cuda::getCurrentDeviceProperties(); - bool is_sm8x = dprops->major >= 8; - TORCH_CHECK(is_sm8x, "Attention combine function only supports Ampere GPUs or newer."); - - auto out_partial_type = out_partial.scalar_type(); - TORCH_CHECK(out_partial_type == at::ScalarType::Float, "Attention combine function only support fp32 data type"); - TORCH_CHECK(lse_partial.scalar_type() == at::ScalarType::Float, "Attention combine function only support fp32 data type"); - - CHECK_DEVICE(out_partial); CHECK_DEVICE(lse_partial); - - TORCH_CHECK(out_partial.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(lse_partial.stride(-2) == 1, "LSE tensor must be contiguous in the seqlen dimension"); - - const auto sizes = out_partial.sizes(); - - const int num_splits = sizes[0]; - const int batch_size = sizes[1]; - const int seqlen = sizes[2]; - const int num_heads = sizes[3]; - const int head_size_og = sizes[4]; - TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256"); - - CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og); - CHECK_SHAPE(lse_partial, num_splits, batch_size, seqlen, num_heads); - - int const alignment = 4; - at::Tensor out_partial_padded; - auto pad = [](at::Tensor x, int alignment) { - return x.size(-1) % alignment == 0 ? x : torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment})); - }; - out_partial_padded = pad(out_partial, alignment); - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size = round_multiple(head_size_og, alignment); - - auto opts = out_partial.options(); - at::ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type()); - TORCH_CHECK(out_type == at::ScalarType::Float || out_type == at::ScalarType::BFloat16 || out_type == at::ScalarType::Half, "Output type must be FP32, FP16 or BF16"); - at::Tensor out; - if (out_.has_value()) { - out = out_.value(); - TORCH_CHECK(out.scalar_type() == out_type); - CHECK_DEVICE(out); - TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); - CHECK_SHAPE(out, batch_size, seqlen, num_heads, head_size_og); - if (head_size_og % alignment != 0) { - out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type)); - } - } else { - out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type)); - } - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)out_partial.get_device()}; - - auto softmax_lse = torch::empty({batch_size, num_heads, seqlen}, opts.dtype(at::kFloat)).transpose(1, 2); - - Flash_fwd_params params {}; // Need to reset the params to set everything to zero - params.is_fp32 = out_type == at::ScalarType::Float; - params.is_bf16 = out_type == at::ScalarType::BFloat16; - params.oaccum_ptr = out_partial_padded.data_ptr(); - params.softmax_lseaccum_ptr = lse_partial.data_ptr(); - params.o_ptr = out.data_ptr(); - params.softmax_lse_ptr = softmax_lse.data_ptr(); - params.b = batch_size; - params.h = num_heads; - params.seqlen_q = seqlen; - params.dv = head_size; - params.num_splits = num_splits; - params.oaccum_split_stride = out_partial_padded.stride(0); - params.oaccum_row_stride = out_partial_padded.stride(2); - params.oaccum_head_stride = out_partial_padded.stride(3); - params.oaccum_batch_stride = out_partial_padded.stride(1); - params.lseaccum_split_stride = lse_partial.stride(0); - params.lseaccum_head_stride = lse_partial.stride(3); - params.lseaccum_batch_stride = lse_partial.stride(1); - params.o_row_stride = out.stride(1); - params.o_head_stride = out.stride(2); - params.o_batch_stride = out.stride(0); - params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; - - if (seqlen > 0 && batch_size > 0) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); - run_mha_fwd_combine(params, stream, false /*enable_pdl*/); - } - - at::Tensor out_padded = out; - if (head_size_og % alignment != 0) { - out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); - // if (out_.has_value()) { out_.value().copy_(out); } - } - - return {out, softmax_lse}; -} - -#ifndef FLASHATTENTION_DISABLE_PYBIND - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "FlashAttention"; - m.def("fwd", &mha_fwd, "Forward pass"); - m.def("bwd", &mha_bwd, "Backward pass"); - m.def("fwd_combine", &mha_combine, "Combine partial attention outputs"); - m.def("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata, "Get scheduler metadata for varlen forward pass"); -} - -#endif \ No newline at end of file diff --git a/flash-attn/flash_bwd_kernel_sm80.h b/flash-attn/flash_bwd_kernel_sm80.h deleted file mode 100644 index aaec00dbe4a493a159629a19df75e5eaafac2c83..0000000000000000000000000000000000000000 --- a/flash-attn/flash_bwd_kernel_sm80.h +++ /dev/null @@ -1,173 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include "cute/tensor.hpp" - -#include -#include -#include -#include - -#include "utils.h" - -namespace flash { - -using namespace cute; - -template -class FlashAttnBwdSm80 { - -public: - - // Type Aliases - static constexpr bool Is_causal = CollectiveMainloop_::Is_causal; - static constexpr bool Is_local = CollectiveMainloop_::Is_local; - static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen); - static constexpr bool Varlen = CollectiveMainloop_::Varlen; - - // Mainloop derived types - using CollectiveMainloop = CollectiveMainloop_; - using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK; - using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP; - using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV; - using ArchTag = typename CollectiveMainloop::ArchTag; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB; - - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - - static_assert(ArchTag::kMinComputeCapability >= 80); - - using TileScheduler = TileScheduler_; - using TileSchedulerArguments = typename flash::TileSchedulerArguments; - using TileSchedulerParams = typename TileScheduler::Params; - - static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMmaSdP{})); - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{})); - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - - // Kernel level shared memory storage - struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128> { - union { - typename CollectiveMainloop::TensorStorage mainloop; - typename CollectiveEpilogue::TensorStorage epilogue; - }; - } tensors; - - alignas(16) typename TileScheduler::SharedStorage smem_scheduler; - - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - - // Device side arguments - struct Arguments { - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - cutlass::KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel entry point API - struct Params { - MainloopParams mainloop{}; - EpilogueParams epilogue{}; - cutlass::KernelHardwareInfo hw_info{}; - TileSchedulerParams scheduler{}; - }; - - // - // Methods - // - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static - Params - to_underlying_arguments(Arguments const& args) { - CUTLASS_TRACE_HOST("to_underlying_arguments():"); - - // Get SM count if needed, otherwise use user supplied SM count - int sm_count = args.hw_info.sm_count; - if (sm_count <= 0) { - CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); - sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); - } - - CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); - - cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; - return { - CollectiveMainloop::to_underlying_arguments(args.mainloop), - CollectiveEpilogue::to_underlying_arguments(args.epilogue), - hw_info, - TileScheduler::to_underlying_arguments(args.scheduler) - }; - } - - // Computes the kernel launch grid shape based on runtime parameters - static dim3 - get_grid_shape(Params const& params) { - return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count); - } - - static dim3 - get_block_shape() { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { - - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - CollectiveMainloop mainloop; - CollectiveEpilogue epilogue; - - TileScheduler scheduler(reinterpret_cast(&shared_storage.smem_scheduler)); - // Initialize matmul objects. - TiledMmadKV tiled_mma_dKV; - - scheduler.init_consumer(); - - int warp_idx = cutlass::canonical_warp_idx_sync(); - CUTLASS_PRAGMA_NO_UNROLL - for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work(params.scheduler) : scheduler.template get_initial_work(params.scheduler); - work_tile_info.is_valid(params.scheduler); - work_tile_info = warp_idx == 0 ? scheduler.template get_next_work(params.scheduler, work_tile_info) : scheduler.template get_next_work(params.scheduler, work_tile_info)) { - - auto block_coord_ = work_tile_info.get_block_coord(params.scheduler); - auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_; - cute::tuple block_coord = {n_block, bidh, bidb}; - - // dK and dV output accumulator. - Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); - Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); - bool tile_valid = mainloop.mma(params.mainloop, tdKrdK, tdVrdV, threadIdx.x, - block_coord, shared_storage); - scheduler.prefetch_next_work(params.scheduler, work_tile_info); - if (tile_valid) { - epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV, - threadIdx.x, block_coord); - } else { - epilogue.store_zero(params.epilogue, threadIdx.x, block_coord); - } - } - - } - -}; - -} // namespace flash diff --git a/flash-attn/flash_bwd_kernel_sm90.h b/flash-attn/flash_bwd_kernel_sm90.h deleted file mode 100644 index b93a0219161556b0d14f33008e0a448294d609ef..0000000000000000000000000000000000000000 --- a/flash-attn/flash_bwd_kernel_sm90.h +++ /dev/null @@ -1,282 +0,0 @@ - -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include "cute/tensor.hpp" - -#include -#include -#include -#include -#include -#include -#include "cutlass/pipeline/pipeline.hpp" - -#include "utils.h" - -namespace flash { - -using namespace cute; - -template -class FlashAttnBwdSm90 { - -public: - - // Type Aliases - static constexpr bool Is_causal = CollectiveMainloop_::Is_causal; - static constexpr bool Is_local = CollectiveMainloop_::Is_local; - static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen); - static constexpr bool Varlen = CollectiveMainloop_::Varlen; - - // Mainloop derived types - using CollectiveMainloop = CollectiveMainloop_; - using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK; - using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP; - using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ClusterShape = typename CollectiveMainloop::ClusterShape; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB; - - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - - static_assert(ArchTag::kMinComputeCapability >= 90); - - using TileScheduler = TileScheduler_; - using TileSchedulerArguments = typename flash::TileSchedulerArguments; - using TileSchedulerParams = typename TileScheduler::Params; - - static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaSdP{})) / cutlass::NumThreadsPerWarpGroup; - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); - - /// Register requirement for Load and Math WGs - static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 2 ? 24 : 32; - static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 240 : 160; - // If you want to print from the producer warp, you'd need to increase the number of registers - // Otherwise you'll get CUDA error. - // static constexpr uint32_t LoadRegisterRequirement = 40; - // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152; - - // Kernel level shared memory storage - struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128> { - union { - typename CollectiveMainloop::TensorStorage mainloop; - typename CollectiveEpilogue::TensorStorage epilogue; - }; - } tensors; - - struct PipelineStorage : cute::aligned_struct<16> { - alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_KV; - alignas(16) typename CollectiveMainloop::MainloopPipeline::SharedStorage pipeline_q; - alignas(16) typename CollectiveMainloop::MainloopPipeline_dO::SharedStorage pipeline_do; - alignas(16) typename TileScheduler::SharedStorage smem_scheduler; - } pipelines; - - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - - // Device side arguments - struct Arguments { - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - cutlass::KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel entry point API - struct Params { - MainloopParams mainloop{}; - EpilogueParams epilogue{}; - cutlass::KernelHardwareInfo hw_info{}; - TileSchedulerParams scheduler{}; - }; - - // - // Methods - // - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static - Params - to_underlying_arguments(Arguments const& args) { - CUTLASS_TRACE_HOST("to_underlying_arguments():"); - - // Get SM count if needed, otherwise use user supplied SM count - int sm_count = args.hw_info.sm_count; - if (sm_count <= 0) { - CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); - sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); - } - - CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); - - cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; - return { - CollectiveMainloop::to_underlying_arguments(args.mainloop), - CollectiveEpilogue::to_underlying_arguments(args.epilogue), - hw_info, - TileScheduler::to_underlying_arguments(args.scheduler) - }; - } - - // Computes the kernel launch grid shape based on runtime parameters - static dim3 - get_grid_shape(Params const& params) { - return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count); - } - - static dim3 - get_block_shape() { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { - - static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; - static constexpr int NumCopyThreads = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup; - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - - using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; - using PipelineParams = typename MainloopPipeline::Params; - using PipelineState = typename MainloopPipeline::PipelineState; - using MainloopPipeline_dO = typename CollectiveMainloop::MainloopPipeline_dO; - using PipelineParams_dO = typename MainloopPipeline_dO::Params; - using PipelineState_dO = typename MainloopPipeline_dO::PipelineState; - static constexpr bool Q_dO_same_stages = std::is_same_v; - - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - int const lane_predicate = cute::elect_one_sync(); - int const warp_idx = cutlass::canonical_warp_idx_sync(); - - // Issue Tma Descriptor Prefetch from a single thread - if (warp_idx == 0 && lane_predicate) { - CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); - CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); - } - - // Obtain warp index - int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; - - PipelineParams pipeline_params; - pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesQ + CollectiveMainloop::TmaTransactionBytesLSE; - int warp_group_idx = cutlass::canonical_warp_group_idx(); - pipeline_params.role = warp_group_idx == 0 - ? MainloopPipeline::ThreadCategory::Producer - : MainloopPipeline::ThreadCategory::Consumer; - pipeline_params.is_leader = warp_group_thread_idx == 0; - pipeline_params.num_consumers = NumMmaThreads; - - if (warp_idx == 0 && lane_predicate) { - shared_storage.pipelines.barrier_KV.init(1 /*numThreads*/); - } - // We're counting on pipeline_q to call cutlass::arch::fence_barrier_init(); - MainloopPipeline pipeline_q(shared_storage.pipelines.pipeline_q, pipeline_params, ClusterShape{}); - auto role_dO = warp_group_idx == 0 - ? MainloopPipeline_dO::ThreadCategory::Producer - : MainloopPipeline_dO::ThreadCategory::Consumer; - PipelineParams_dO pipeline_params_dO {pipeline_params.transaction_bytes, role_dO, pipeline_params.is_leader, pipeline_params.num_consumers}; - MainloopPipeline_dO pipeline_do(shared_storage.pipelines.pipeline_do, cute::conditional_return(pipeline_params, pipeline_params_dO), ClusterShape{}); - - CollectiveMainloop mainloop; - CollectiveEpilogue epilogue; - - // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster - if constexpr (size(ClusterShape{}) > 1) { - cute::cluster_arrive_relaxed(); - cute::cluster_wait(); - } else { - __syncthreads(); - } - - TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); - - if (warp_group_idx == 0) { // Producer - cutlass::arch::warpgroup_reg_dealloc(); - - int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); - if (warp_idx_in_warpgroup == 0) { // Load K, V, and do TMA on Q and dO - PipelineState smem_pipe_write = cutlass::make_producer_start_state(); - PipelineState_dO smem_pipe_write_do = cutlass::make_producer_start_state(); - for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); - work_tile_info.is_valid(params.scheduler); - work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { - auto block_coord_ = work_tile_info.get_block_coord(params.scheduler); - auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_; - cute::tuple block_coord = {n_block, bidh, bidb}; - auto scheduler_prefetch = [&scheduler, ¶ms, &work_tile_info]() { - scheduler.prefetch_next_work(params.scheduler, work_tile_info); - }; - mainloop.load(params.mainloop, pipeline_q, pipeline_do, smem_pipe_write, - smem_pipe_write_do, shared_storage, scheduler_prefetch, block_coord); - } - mainloop.load_tail(pipeline_q, pipeline_do, smem_pipe_write, smem_pipe_write_do); - } else if (warp_idx_in_warpgroup == 1) { - for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); - work_tile_info.is_valid(params.scheduler); - work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { - auto block_coord_ = work_tile_info.get_block_coord(params.scheduler); - auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_; - cute::tuple block_coord = {n_block, bidh, bidb}; - mainloop.store_dq(params.mainloop, shared_storage, block_coord); - } - } - } else { // Consumer - cutlass::arch::warpgroup_reg_alloc(); - // Initialize matmul objects. - TiledMmadKV tiled_mma_dKV; - - PipelineState smem_pipe_read; - PipelineState_dO smem_pipe_read_do; - - mainloop.mma_init(); - scheduler.init_consumer(); - - int work_idx = 0; - CUTLASS_PRAGMA_NO_UNROLL - for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); - work_tile_info.is_valid(params.scheduler); - work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { - auto block_coord_ = work_tile_info.get_block_coord(params.scheduler); - auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_; - cute::tuple block_coord = {n_block, bidh, bidb}; - - // dK and dV output accumulator. - Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); - Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); - bool tile_valid = mainloop.mma( - params.mainloop, pipeline_q, pipeline_do, smem_pipe_read, smem_pipe_read_do, - tdKrdK, tdVrdV, threadIdx.x - NumCopyThreads, work_idx, block_coord, shared_storage); - if (tile_valid) { - epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV, - threadIdx.x - NumCopyThreads, block_coord); - } else { - epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord); - } - - } - epilogue.store_tail(); - } - - } - -}; - -} // namespace flash diff --git a/flash-attn/flash_bwd_launch_template.h b/flash-attn/flash_bwd_launch_template.h deleted file mode 100644 index 76ded0407ecff487c376d0d16309c098ab9b2aa2..0000000000000000000000000000000000000000 --- a/flash-attn/flash_bwd_launch_template.h +++ /dev/null @@ -1,377 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include "cute/tensor.hpp" - -#include "cutlass/device_kernel.h" // For device_kernel -#include "cutlass/kernel_launch.h" // For kernel_launch -#include "cutlass/cluster_launch.hpp" // For ClusterLauncher - -#include "static_switch.h" -#include "flash.h" -#include "flash_bwd_preprocess_kernel.h" -#include "flash_bwd_postprocess_kernel.h" -#include "tile_scheduler.hpp" -#include "mainloop_bwd_sm90_tma_gmma_ws.hpp" -#include "mainloop_bwd_sm80.hpp" -#include "epilogue_bwd.hpp" -#include "flash_bwd_kernel_sm90.h" -#include "flash_bwd_kernel_sm80.h" - -using namespace cute; - -template -void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { - static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time."); - using ElementAccum = float; - using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; - - int const total_q_padded_rounded = cute::round_up(params.total_q + params.b * kBlockM, kBlockM); - int const total_k_padded_rounded = cute::round_up(params.total_k + params.b * kBlockN, kBlockN); - bool const is_varlen_q = params.cu_seqlens_q; - bool const is_varlen_k = params.cu_seqlens_k; - int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q; - int seqlen_k = !is_varlen_k ? params.seqlen_k : params.total_k; - int seqlen_q_rounded = !is_varlen_q ? params.seqlen_q_rounded : total_q_padded_rounded; - int seqlen_k_rounded = !is_varlen_k ? params.seqlen_k_rounded : total_k_padded_rounded; - int batch_q = !is_varlen_q ? params.b : 1; - int batch_k = !is_varlen_k ? params.b : 1; - - using TileShape_MK = cute::Shape, Int>; - using PreprocessKernel = flash::FlashAttnBwdPreprocess; - typename PreprocessKernel::Arguments preprocess_args { - static_cast(params.o_ptr), - {seqlen_q, params.d, params.h, batch_q}, // shape_O - {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0}, // stride_O - static_cast(params.do_ptr), - {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO - static_cast(params.dsoftmax_sum), - {seqlen_q_rounded, params.h, batch_q}, // shape_dPsum - {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum - static_cast(params.softmax_lse_ptr), - {_1{}, seqlen_q, !is_varlen_q ? params.h * params.seqlen_q : 0}, // stride_LSE - static_cast(params.softmax_lse_log2_ptr), - {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2 - static_cast(params.dq_accum_ptr), - {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum - {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * seqlen_q_rounded * params.h : 0}, // stride_dQaccum - params.b, - params.dq_semaphore, - params.cu_seqlens_q, - params.seqused_q - }; - typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args); - int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM); - dim3 grid_m(num_m_block, params.h, params.b); - cutlass::kernel_launch(grid_m, PreprocessKernel::MaxThreadsPerBlock, PreprocessKernel::SharedStorageSize, stream, preprocess_params, false /*launch_with_pdl*/); - CHECK_CUDA_KERNEL_LAUNCH(); - - using TileShape_MNK = cute::Shape, Int, Int>; - using ClusterShape = cute::Shape<_1, Int<1>, _1>; // Currently doesn't not support cluster - // Stages_dS_or_QSm80 is Stages_dS if Sm90 and Stages if Sm80 - static constexpr int Stages = Arch >= 90 ? 2 : Stages_dS_or_QSm80; - static constexpr int Stages_dS = Arch >= 90 ? Stages_dS_or_QSm80 : 1; - using CollectiveMainloop = std::conditional_t< - Arch >= 90, - flash::CollectiveMainloopBwdSm90, - flash::CollectiveMainloopBwdSm80 - >; - using CollectiveEpilogue = std::conditional_t< - !GQA, - flash::CollectiveEpilogueBwd= 90 ? 1 : cutlass::NumWarpsPerWarpGroup) / AtomLayoutNdKV>, - flash::CollectiveEpilogueBwdGQA - >; - using Scheduler = flash::SingleTileScheduler; - using AttnKernel = std::conditional_t< - Arch >= 90, - flash::enable_sm90_or_later>, - flash::enable_sm80_to_sm89> - >; - - typename CollectiveMainloop::Arguments mainloop_args { - static_cast(params.q_ptr), - {seqlen_q, params.d, params.h, batch_q}, // shape_Q - {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q - static_cast(params.k_ptr), - {seqlen_k, params.d, params.h_k, batch_k}, // shape_K - {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K - static_cast(params.v_ptr), - {params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0}, // stride_V - static_cast(params.do_ptr), - {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO - static_cast(params.dq_accum_ptr), - {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum - {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum - static_cast(params.softmax_lse_log2_ptr), - {seqlen_q_rounded, params.h, batch_q}, // shape_LSE - {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2 - static_cast(params.dsoftmax_sum), - {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum - params.scale_softmax, - params.window_size_left, params.window_size_right, - params.softcap, - params.b, - params.dq_semaphore, - params.cu_seqlens_q, params.cu_seqlens_k, - params.seqused_q, params.seqused_k - }; - // The case work with GQA is ugly but idk how to fix it. - typename CollectiveEpilogue::Arguments epilogue_args { - static_cast(!GQA ? params.dk_ptr : params.dk_accum_ptr), - [&] { - if constexpr (!GQA) { - return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.d, params.h, batch_k}; // shape_dK - } else { - return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}; // shape_dKaccum - } - }(), - [&] { - if constexpr (!GQA) { - return typename CollectiveEpilogue::StridedKV {params.dk_row_stride, _1{}, params.dk_head_stride, !is_varlen_k ? params.dk_batch_stride : 0}; // stride_dK - } else { - return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dKaccum - } - }(), - static_cast(!GQA ? params.dv_ptr : params.dv_accum_ptr), - [&] { - if constexpr (!GQA) { - return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0}; // stride_dV - } else { - return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum - } - }(), - params.h, - params.dk_semaphore, - params.dv_semaphore, - params.cu_seqlens_k, - params.seqused_k, - }; - - int num_blocks_n = cutlass::ceil_div(params.seqlen_k, get<1>(TileShape_MNK{})); - num_blocks_n = cutlass::round_up(num_blocks_n, size<1>(ClusterShape{})); - typename flash::TileSchedulerArguments scheduler_args { - num_blocks_n, params.h, params.b, 1 /*num_splits*/, - params.h / params.h_k, - params.seqlen_k, - params.seqlen_q, params.d, params.dv, sizeof(Element), - params.tile_count_semaphore, params.cu_seqlens_k, params.seqused_k - }; - - int device; - cudaGetDevice(&device); - typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({ - mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args - }); - - dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params); - dim3 block_dims = AttnKernel::get_block_shape(); - int smem_size = AttnKernel::SharedStorageSize; - // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q)); - // int smem_size_do = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_do)); - // int smem_size_ds = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_ds)); - // int smem_size_dqacc = [&] { - // if constexpr (Arch >= 90) { - // return sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dqacc)); - // } else { - // return 0; - // } - // }(); - // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k)); - // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)); - // int smem_size_lse = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_lse)); - // int smem_size_dpsum = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dpsum)); - // printf("smem_size = %d, q = %d, k = %d, v = %d, do = %d, ds = %d, dqacc = %d, lse = %d, dpsum = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_do, smem_size_ds, smem_size_dqacc, smem_size_lse, smem_size_dpsum); - if constexpr (size(ClusterShape{}) > 1) { - void const* kernel = (void const*) cutlass::device_kernel; - if (smem_size >= 48 * 1024) { - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); - cutlass::ClusterLauncher::launch( - grid_dims, cluster_dims, block_dims, smem_size, stream, kernel, kernel_params, false /*launch_with_pdl*/); - } else { - if (smem_size >= 48 * 1024) { - CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - cutlass::kernel_launch(grid_dims, block_dims, smem_size, stream, kernel_params, false /*launch_with_pdl*/); - } - CHECK_CUDA_KERNEL_LAUNCH(); - - using PostprocessKernel = flash::FlashAttnBwdPostprocessConvertdQ; - typename PostprocessKernel::Arguments postprocess_args { - static_cast(params.dq_accum_ptr), - {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum - {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum - static_cast(params.dq_ptr), - {seqlen_q, params.d, params.h, batch_q}, // shape_dQ - {params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride}, // stride_dQ - params.scale_softmax, - params.cu_seqlens_q, - params.seqused_q - }; - typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args); - int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{})); - dim3 grid_m_postprocess(num_m_block_postprocess, params.h, params.b); - int smem_size_postprocess = PostprocessKernel::SharedStorageSize; - if (smem_size_postprocess >= 48 * 1024) { - CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess)); - } - cutlass::kernel_launch(grid_m_postprocess, PostprocessKernel::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_params, false /*launch_with_pdl*/); - CHECK_CUDA_KERNEL_LAUNCH(); - - if constexpr (GQA) { - using TileShape_NK = cute::Shape, Int>; - using PostprocessKerneldKV = flash::FlashAttnBwdPostprocessConvertdQ; - typename PostprocessKerneldKV::Arguments postprocess_dK_args { - static_cast(params.dk_accum_ptr), - {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dKaccum - {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dKaccum - static_cast(params.dk_ptr), - {seqlen_k, params.d, params.h_k, batch_k}, // shape_dK - {params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride}, // stride_dK - 1.f, - params.cu_seqlens_k, - params.seqused_k - }; - typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args); - typename PostprocessKerneldKV::Arguments postprocess_dV_args { - static_cast(params.dv_accum_ptr), - {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dVaccum - {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dVaccum - static_cast(params.dv_ptr), - {seqlen_k, params.d, params.h_k, batch_k}, // shape_dV - {params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride}, // stride_dV - 1.f, - params.cu_seqlens_k, - params.seqused_k - }; - typename PostprocessKerneldKV::Params postprocess_dV_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dV_args); - int num_n_block_postprocess = cute::ceil_div(params.seqlen_k, get<0>(TileShape_NK{})); - dim3 grid_n_postprocess(num_n_block_postprocess, params.h_k, params.b); - int smem_size_postprocess = PostprocessKerneldKV::SharedStorageSize; - if (smem_size_postprocess >= 48 * 1024) { - CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess)); - } - cutlass::kernel_launch(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dK_params, false /*launch_with_pdl*/); - CHECK_CUDA_KERNEL_LAUNCH(); - cutlass::kernel_launch(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dV_params, false /*launch_with_pdl*/); - CHECK_CUDA_KERNEL_LAUNCH(); - } - -} - -template -void run_mha_bwd_dispatch(Flash_bwd_params ¶ms, cudaStream_t stream) { - VARLEN_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] { - BOOL_SWITCH(params.h != params.h_k, GQA, [&] { -// BOOL_SWITCH(params.deterministic, Deterministic, [&] { - // run_flash_bwd(params, stream); - run_flash_bwd(params, stream); -// }); - }); - }); -} - - -template -void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) { - CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] { - if constexpr (Arch >= 90) { - if constexpr (Is_causal && Has_softcap) { - // register spill with 128 x 128 - run_mha_bwd_dispatch(params, stream); - } else { - // With ShuffleStats we no longer have register spilling when Has_softcap and using 128 x 128 block. - run_mha_bwd_dispatch(params, stream); - } - } else if constexpr (Arch == 86 || Arch == 89) { - run_mha_bwd_dispatch(params, stream); - // run_mha_bwd_dispatch(params, stream); - // run_mha_bwd_dispatch(params, stream); - // run_mha_bwd_dispatch(params, stream); - } else { - run_mha_bwd_dispatch(params, stream); - } - }); -} - -template -void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) { - CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] { - if constexpr (Arch >= 90) { - run_mha_bwd_dispatch(params, stream); - } else if constexpr (Arch == 86 || Arch == 89) { - run_mha_bwd_dispatch(params, stream); - } else { - run_mha_bwd_dispatch(params, stream); - } - }); -} - -template -void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { - CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] { - if constexpr (Arch >= 90) { - if constexpr (Is_causal || Is_local || Has_softcap) { - run_mha_bwd_dispatch(params, stream); - } else { - run_mha_bwd_dispatch(params, stream); - } - } else if constexpr (Arch == 86 || Arch == 89) { - run_mha_bwd_dispatch(params, stream); - } else { - run_mha_bwd_dispatch(params, stream); - } - }); -} - -template -void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream) { - CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] { - if constexpr (Arch >= 90) { - run_mha_bwd_dispatch(params, stream); - } else if constexpr (Arch == 86 || Arch == 89) { - run_mha_bwd_dispatch(params, stream); - } else { - run_mha_bwd_dispatch(params, stream); - } - }); -} - -template -void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream) { - CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] { - if constexpr (Arch >= 90) { - run_mha_bwd_dispatch(params, stream); - } else if constexpr (Arch == 86 || Arch == 89) { - run_mha_bwd_dispatch(params, stream); - // run_mha_bwd_dispatch(params, stream); - } else { - run_mha_bwd_dispatch(params, stream); - } - }); -} diff --git a/flash-attn/flash_bwd_postprocess_kernel.h b/flash-attn/flash_bwd_postprocess_kernel.h deleted file mode 100644 index c91e261507dd906ebb457db3d3f5dd10ef4b5109..0000000000000000000000000000000000000000 --- a/flash-attn/flash_bwd_postprocess_kernel.h +++ /dev/null @@ -1,256 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include "cute/tensor.hpp" - -#include -#include -#include -#include -#include "cutlass/arch/barrier.h" - -#include "seqlen.h" -#include "utils.h" - -namespace flash { - -using namespace cute; - -template -class FlashAttnBwdPostprocessConvertdQ { - -public: - - // Type Aliases - using TileShape_MK = TileShape_MK_; - using ArchTag = ArchTag_; - - static_assert(ArchTag::kMinComputeCapability >= 75); - static constexpr bool IsSm90 = ArchTag::kMinComputeCapability >= 90; - - static constexpr uint32_t MaxThreadsPerBlock = kNThreads; - static constexpr uint32_t MinBlocksPerMultiprocessor = 2; - - static constexpr int kBlockM = get<0>(TileShape_MK{}); - static constexpr int kHeadDim = get<1>(TileShape_MK{}); - static_assert(!IsSm90 || kNThreads % cutlass::NumThreadsPerWarpGroup == 0, "kNThreads must be a multiple of NumThreadsPerWarpGroup"); - static constexpr int NumdQWarpGgroups = kNThreads / cutlass::NumThreadsPerWarpGroup; - using R2SLayoutAtomdQaccum = std::conditional_t< - IsSm90, - Layout, Int>>, - Layout>> - >; - using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2SLayoutAtomdQaccum{}, - Layout>>{})); // Val layout, 1 or 4 vals per read - using G2SLayoutAtomdQaccum = Layout>>; - // UniversalCopy instead of AutoVectorizingCopyWithAssumedAlignment as the latter generates cp.async instructions - using G2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, G2SLayoutAtomdQaccum{}, - Layout>{})); // Val layout, 4 vals per read - // We don't do bound checking for the gmem -> smem load so we just assert here. - static_assert(IsSm90 || (kBlockM * kHeadDim) % (kNThreads * 4) == 0); - static constexpr int SmemdQaccumSize = size(TileShape_MK{}); - using SmemLayoutdQaccumFlat = Layout>>; - using SmemLayoutdQaccum = std::conditional_t< - IsSm90, - Layout, Int>>, - Layout>> - >; - - // We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs, - // then setting kBlockKSmem to 32 will cause "Static shape_div failure". - // We want to treat it as 64 x 48, so kBlockKSmem should be 16. - static constexpr int MmaShapeN = get<1>(typename TiledMma::AtomShape_MNK{}); - static constexpr int kBlockKSmem = MmaShapeN % 64 == 0 ? 64 : (MmaShapeN % 32 == 0 ? 32 : 16); - static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1); - using SmemLayoutAtomdQ = - decltype(composition(Swizzle{}, - Layout, Int>, - Stride, _1>>{})); - using SmemLayoutdQ = decltype(tile_to_shape(SmemLayoutAtomdQ{}, TileShape_MK{})); - using SmemLayoutdQt = - decltype(cute::composition(SmemLayoutdQ{}, - make_layout(make_shape(get<1>(TileShape_MK{}), get<0>(TileShape_MK{})), - make_stride(Int(TileShape_MK{})>{}, _1{})))); - - using SmemCopyAtomdQ = Copy_Atom< - std::conditional_t< - IsSm90, - std::conditional_t, - AutoVectorizingCopyWithAssumedAlignment<128> - >, - Element>; - - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); - static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, int(MaxThreadsPerBlock)); - static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow"); - using GmemLayoutAtom = Layout, Int>, - Stride, _1>>; - using GmemTiledCopy = decltype( - make_tiled_copy(Copy_Atom, Element>{}, - GmemLayoutAtom{}, - Layout>>{})); // Val layout, 8 or 16 vals per load - - struct SharedStorage : cute::aligned_struct<128> { - cute::array_aligned> smem_dqacc; - cute::array_aligned> smem_dq; - alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_dQaccum; - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - - using ShapedQ = cute::Shape; // (seqlen_q, d, head, batch) - using StridedQ = cute::Stride; - using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch) - using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; - - // Device side arguments - struct Arguments { - ElementAccum const* ptr_dQaccum; - ShapedQaccum const shape_dQaccum; - StridedQaccum const stride_dQaccum; - Element* ptr_dQ; - ShapedQ const shape_dQ; - StridedQ const stride_dQ; - float const softmax_scale; - int const* cu_seqlens = nullptr; - int const* seqused = nullptr; - }; - - // Kernel entry point API - struct Params { - ElementAccum const* ptr_dQaccum; - ShapedQaccum const shape_dQaccum; - StridedQaccum const stride_dQaccum; - Element* ptr_dQ; - ShapedQ const shape_dQ; - StridedQ const stride_dQ; - float const softmax_scale; - int const* cu_seqlens = nullptr; - int const* seqused = nullptr; - }; - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static - Params - to_underlying_arguments(Arguments const& args) { - return { - args.ptr_dQaccum, - args.shape_dQaccum, - args.stride_dQaccum, - args.ptr_dQ, - args.shape_dQ, - args.stride_dQ, - args.softmax_scale, - args.cu_seqlens, - args.seqused - }; - } - - CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { - - static constexpr int kBlockM = get<0>(TileShape_MK{}); - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - Tensor sdQaccum = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccum{}); - Tensor sdQaccum_flat = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccumFlat{}); - Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQ{}); - Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQt{}); - - int const thread_idx = threadIdx.x; - int const m_block = blockIdx.x; - int const bidh = blockIdx.y; - int const bidb = blockIdx.z; - - flash::SeqlenInfo seqlen_info(bidb, size<0>(params.shape_dQ), params.cu_seqlens, params.seqused); - bool const is_varlen = params.cu_seqlens; - if (is_varlen && m_block * kBlockM >= seqlen_info.seqlen) { return; } - - // Step 1: load dQaccum from gmem to smem - Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), - params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0); - Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(m_block)); // (M * K) - if constexpr (IsSm90) { // Use BulkCopy - static constexpr uint32_t TmaTransactionBytesdQaccum = static_cast(size(SmemLayoutdQaccumFlat{}) * cute::sizeof_bits_v / 8); - auto bulk_copy = Copy_Traits{}; - // if (thread0()) { print(gdQaccum); printf("\n"); print(sdQaccum_flat); printf("\n"); } - if (thread_idx == 0) { - shared_storage.barrier_dQaccum.init(1 /*numThreads*/); - shared_storage.barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum); - copy(bulk_copy.with(*reinterpret_cast(&shared_storage.barrier_dQaccum)), gdQaccum, sdQaccum_flat); - } - __syncthreads(); - shared_storage.barrier_dQaccum.wait(0); - } else { - G2STiledCopydQaccum g2s_tiled_copy_dQaccum; - auto g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_thread_slice(thread_idx); - Tensor tdQgdQaccumg2s = g2s_thr_copy_dQaccum.partition_S(gdQaccum); - Tensor tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum); - cute::copy(g2s_tiled_copy_dQaccum, tdQgdQaccumg2s, tdQsdQaccumg2s); - __syncthreads(); - } - - // __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccum); } - - // Step 2: Load dQaccum from smem to register, then convert fp32 -> fp16/bf16 - R2STiledCopydQaccum s2r_tiled_copy_dQaccum; - auto s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_thread_slice(thread_idx); - Tensor tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum); - TiledMma tiled_mma_dQ; - Tensor taccdQrdQaccum = partition_fragment_C(tiled_mma_dQ, select(TileShape_MK{})); - // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tiled_mma_dQ); printf("\n"); } - // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tdQsdQaccum); } - // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(taccdQrdQaccum); } - CUTE_STATIC_ASSERT_V(size(taccdQrdQaccum) == size(tdQsdQaccum)); - Tensor tdQrdQaccum = s2r_thr_copy_dQaccum.retile_D(taccdQrdQaccum); - cute::copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum); - #pragma unroll - for (int i = 0; i < size(taccdQrdQaccum); ++i) { taccdQrdQaccum(i) *= params.softmax_scale; } - // Convert tdQrdQ from fp32 to fp16 - Tensor rdQ = make_tensor_like(taccdQrdQaccum); - flash::convert_type_out(taccdQrdQaccum, rdQ); - - // Step 3: Copy dQ from register to smem - auto smem_tiled_copy_dQ = make_tiled_copy_C(SmemCopyAtomdQ{}, tiled_mma_dQ); - auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(thread_idx); - Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N) - // if (cute::thread0()) { print(smem_tiled_copy_dQ); } - // if (cute::thread0()) { print(smem_thr_copy_dQ); } - // if (cute::thread0()) { print(sdQ); } - Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(cute::conditional_return(sdQ, sdQt)); // ((Atom,AtomNum),PIPE_M,PIPE_N) - cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); - __syncthreads(); - - // Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem - Tensor mdQ = make_tensor(make_gmem_ptr(params.ptr_dQ), params.shape_dQ, params.stride_dQ)(_, _, bidh, !is_varlen ? bidb : 0); - Tensor gdQ = local_tile(domain_offset(make_coord(seqlen_info.offset, _0{}), mdQ), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K) - GmemTiledCopy gmem_tiled_copy_dQ; - auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(thread_idx); - Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); - - Tensor tdQrdQ = make_fragment_like(tdQsdQ); - Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cute::make_identity_tensor(TileShape_MK{})); - Tensor tdQpdQ = make_tensor(make_shape(size<2>(tdQgdQ))); - #pragma unroll - for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(_0{}, _0{}, k)) < get<1>(params.shape_dQ); } - // Need to check OOB when reading from smem if kBlockM isn't evenly tiled - static constexpr bool EvenM = kBlockM % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0; - flash::copy( - gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ, tdQcdQ, tdQpdQ, kBlockM); - - // Step 5: Copy dQ from register to gmem - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, std::min(seqlen_info.seqlen - m_block * kBlockM, kBlockM) - ); - } - -}; - -} // namespace flash diff --git a/flash-attn/flash_bwd_preprocess_kernel.h b/flash-attn/flash_bwd_preprocess_kernel.h deleted file mode 100644 index 85e877f9d4fd2821ef38f78b12f1e1fdde2b6d37..0000000000000000000000000000000000000000 --- a/flash-attn/flash_bwd_preprocess_kernel.h +++ /dev/null @@ -1,252 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include "cute/tensor.hpp" - -#include -#include -#include -#include - -#include "seqlen.h" -#include "utils.h" - -namespace flash { - -using namespace cute; - -template -class FlashAttnBwdPreprocess { - -public: - - // Type Aliases - using TileShape_MK = TileShape_MK_; - using ArchTag = ArchTag_; - - static_assert(std::is_same_v && ArchTag::kMinComputeCapability >= 75 || - std::is_same_v && ArchTag::kMinComputeCapability >= 80 || - std::is_same_v && ArchTag::kMinComputeCapability >= 89); - - static constexpr uint32_t MaxThreadsPerBlock = 256; - static constexpr uint32_t MinBlocksPerMultiprocessor = 2; - static constexpr int SharedStorageSize = 0; - - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(get<1>(TileShape_MK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); - static constexpr int kBlockM = get<0>(TileShape_MK{}); - static constexpr int kHeadDim = get<1>(TileShape_MK{}); - // We want kBlockKGmem to be a power of 2 so that when we do the summing, - // it's just between threads in the same warp - static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); - static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; - static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow"); - using GmemLayoutAtom = Layout, Int>, - Stride, _1>>; - using GmemTiledCopy = decltype( - make_tiled_copy(Copy_Atom, Element>{}, - GmemLayoutAtom{}, - Layout>>{})); // Val layout, 8 or 16 vals per load - - static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum); - static_assert((kBlockM * kHeadDim / kGmemElemsPerLoadAccum) % MaxThreadsPerBlock == 0, "MaxThreadsPerBlock must divide kBlockM * kHeadDim / kGmemElemsPerLoadAccum"); - using GmemLayoutAtomAccum = Layout>>; - using GmemTiledCopyAccum = decltype( - make_tiled_copy(Copy_Atom, ElementAccum>{}, - GmemLayoutAtomAccum{}, - Layout>>{})); // Val layout, 4 vals per store - - using ShapeO = cute::Shape; // (seqlen_q, d, head, batch) - using StrideO = cute::Stride; - using ShapedPsum = cute::Shape; // (seqlen_q, head, batch) - using StridedPsum = cute::Stride<_1, int64_t, int64_t>; - using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch) - using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; - - // Device side arguments - struct Arguments { - Element const* ptr_O; - ShapeO const shape_O; - StrideO const stride_O; - Element const* ptr_dO; - StrideO const stride_dO; - float* ptr_dPsum; - ShapedPsum const shape_dPsum; - StridedPsum const stride_dPsum; - float const* ptr_LSE; - StridedPsum const stride_LSE; - float *ptr_LSE_log2; - StridedPsum const stride_LSE_log2; - ElementAccum* ptr_dQaccum; - ShapedQaccum const shape_dQaccum; - StridedQaccum const stride_dQaccum; - int num_batch; // We need this to know the size of dq_semaphore in case of varlen - int* dq_semaphore; - int const* cu_seqlens = nullptr; - int const* seqused = nullptr; - }; - - // Kernel entry point API - struct Params { - Element const* ptr_O; - ShapeO const shape_O; - StrideO const stride_O; - Element const* ptr_dO; - StrideO const stride_dO; - float* ptr_dPsum; - ShapedPsum const shape_dPsum; - StridedPsum const stride_dPsum; - float const* ptr_LSE; - StridedPsum const stride_LSE; - float* ptr_LSE_log2; - StridedPsum const stride_LSE_log2; - ElementAccum* ptr_dQaccum; - ShapedQaccum const shape_dQaccum; - StridedQaccum const stride_dQaccum; - int num_batch; - int* dq_semaphore; - int const* cu_seqlens = nullptr; - int const* seqused = nullptr; - }; - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static - Params - to_underlying_arguments(Arguments const& args) { - return { - args.ptr_O, - args.shape_O, - args.stride_O, - args.ptr_dO, - args.stride_dO, - args.ptr_dPsum, - args.shape_dPsum, - args.stride_dPsum, - args.ptr_LSE, - args.stride_LSE, - args.ptr_LSE_log2, - args.stride_LSE_log2, - args.ptr_dQaccum, - args.shape_dQaccum, - args.stride_dQaccum, - args.num_batch, - args.dq_semaphore, - args.cu_seqlens, - args.seqused - }; - } - - CUTLASS_DEVICE - void - operator()(Params const& params, [[maybe_unused]] char* smem_buf) { - - static constexpr int kBlockM = get<0>(TileShape_MK{}); - - int const thread_idx = threadIdx.x; - int const m_block = blockIdx.x; - int const bidh = blockIdx.y; - int const bidb = blockIdx.z; - - flash::SeqlenInfo seqlen_info(bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused); - bool const is_varlen = Varlen && params.cu_seqlens; - int const seqlen_o = seqlen_info.seqlen; - if (is_varlen && m_block * kBlockM >= seqlen_o) { return; } - - Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !is_varlen ? bidb : 0); - Tensor gO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K) - Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_O, params.stride_dO)(_, _, bidh, !is_varlen ? bidb : 0); - Tensor gdO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K) - - auto shape_LSE = select<0, 2, 3>(params.shape_O); - Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), shape_LSE, params.stride_LSE)(_, bidh, !is_varlen ? bidb : 0); - Tensor gLSE = local_tile(cute::domain_offset(make_coord(seqlen_info.offset), mLSE), Shape>{}, make_coord(m_block)); - static_assert(kBlockM <= MaxThreadsPerBlock); - float lse = thread_idx < seqlen_o - m_block * kBlockM && thread_idx < kBlockM ? gLSE(thread_idx) : INFINITY; - - GmemTiledCopy gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); - - Tensor tOgO = gmem_thr_copy_O.partition_S(gO); - Tensor tOgdO = gmem_thr_copy_O.partition_S(gdO); - // Construct identity layout for gO - Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_O.partition_D(cO); - Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); - #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } - - // (8, kBlockM / 32, kHeadDim / 64) or (8, kBlockM / 16, kHeadDim / 128) - Tensor tOrO = make_fragment_like(tOgO); - Tensor tOrdO = make_fragment_like(tOgdO); - flash::copy( - gmem_tiled_copy_O, tOgO, tOrO, tOcO, tOpO, seqlen_o - m_block * kBlockM - ); - flash::copy( - gmem_tiled_copy_O, tOgdO, tOrdO, tOcO, tOpO, seqlen_o - m_block * kBlockM - ); - // if (threadIdx.x == 222) { printf("bidx = %d, bidy = %d, bidz = %d, seqlen_o = %d, m_block = %d, seqlen_o - m_block * kBlockM = %d, tOgO addr = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, seqlen_o, m_block, seqlen_o - m_block * kBlockM, &tOgO(0));} - - // Reshape from e.g. (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, (8, kHeadDim / 64)) - Layout l = make_layout(get<1>(tOrO.layout()), make_layout(get<0>(tOrO.layout()), get<2>(tOrO.layout()))); - Tensor tOrO_l = make_tensor(tOrO.data(), l); - Tensor o_fp32 = make_tensor_like(tOrO_l); - flash::convert_type_out(tOrO_l, o_fp32); - Tensor tOrdO_l = make_tensor(tOrdO.data(), l); - Tensor do_fp32 = make_tensor_like(tOrdO_l); - flash::convert_type_out(tOrdO_l, do_fp32); - // Sum across the last dimension - Tensor dP_sum = make_tensor(make_shape(size<0>(o_fp32))); - #pragma unroll - for (int mi = 0; mi < size<0>(o_fp32); ++mi) { - float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0); - #pragma unroll - for (int ni = 1; ni < size<1>(o_fp32); ni++) { - dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni); - } - flash::SumOp sum_op; - dP_sum(mi) = flash::Allreduce::run(dP_sum_cur, sum_op); - } - - Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_dPsum, params.stride_dPsum)(_, bidh, !is_varlen ? bidb : 0); - Tensor gdPsum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mdPsum), Shape>{}, make_coord(m_block)); - if (get<1>(tOcO(_0{}, _0{}, _0{})) == 0) { - #pragma unroll - for (int mi = 0; mi < size(dP_sum); ++mi) { - int const row = get<0>(tOcO(_0{}, mi, _0{})); - gdPsum(row) = row < seqlen_o - m_block * kBlockM ? dP_sum(mi) : 0; - } - } - - int const seqlen_rounded = cute::round_up(seqlen_o, kBlockM); - Tensor mLSElog2 = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_dPsum, params.stride_LSE_log2)(_, bidh, !is_varlen ? bidb : 0); - Tensor gLSElog2 = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mLSElog2), Shape>{}, make_coord(m_block)); - if (thread_idx < seqlen_rounded - m_block * kBlockM && thread_idx < kBlockM) { - gLSElog2(thread_idx) = lse == -INFINITY ? 0.f : lse * float(M_LOG2E); - } - - if constexpr (Clear_dQaccum) { - Tensor mdQaccum = make_tensor(make_gmem_ptr(params.ptr_dQaccum), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0); - Tensor gdQaccum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(m_block)); - GmemTiledCopyAccum gmem_tiled_copy_dQaccum; - auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(thread_idx); - Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum); - Tensor zero = make_fragment_like(tdQgdQaccum); - clear(zero); - cute::copy(Copy_Atom, ElementAccum>{}, zero, tdQgdQaccum); - } - - if (params.dq_semaphore != nullptr && thread_idx == 0) { - int const num_batch = params.num_batch; - int const num_head = get<2>(params.shape_O); - params.dq_semaphore[bidh + bidb * num_head + m_block * num_head * num_batch] = 0; - } - - } - -}; - -} // namespace flash diff --git a/flash-attn/flash_fwd_combine.cu b/flash-attn/flash_fwd_combine.cu deleted file mode 100644 index 3e85a0a212cb831a62363dac16f5016f27271934..0000000000000000000000000000000000000000 --- a/flash-attn/flash_fwd_combine.cu +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. - -#include "flash_fwd_combine_launch_template.h" - -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); - -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); - -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); diff --git a/flash-attn/flash_fwd_combine_kernel.h b/flash-attn/flash_fwd_combine_kernel.h deleted file mode 100644 index 3aa5484cbca404e9c375d06f758be8d5b4eac05d..0000000000000000000000000000000000000000 --- a/flash-attn/flash_fwd_combine_kernel.h +++ /dev/null @@ -1,702 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include "cute/tensor.hpp" - -#include -#include -#include -#include -#include - -#include "cutlass/arch/grid_dependency_control.h" - -#include "seqlen.h" -#include "utils.h" - -namespace flash { - -using namespace cute; - -template -class FlashAttnFwdCombine { - -public: - - // Type Aliases - using TileShape_MK = TileShape_MK_; - using ArchTag = ArchTag_; - static constexpr int kMaxSplits = 1 << kLogMaxSplits_; - static constexpr int AlignmentLSE = std::min(AlignmentLSE_, int(128 / 8 / sizeof(float))); - static_assert(AlignmentLSE >= 1); - static constexpr int kStages = 4; - - static_assert(ArchTag::kMinComputeCapability >= 75); - static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80; - - static constexpr uint32_t MaxThreadsPerBlock = kNThreads; - static constexpr uint32_t MinBlocksPerMultiprocessor = 2; - - static constexpr int kBlockM = get<0>(TileShape_MK{}); - static constexpr int kBlockK = get<1>(TileShape_MK{}); - - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(ElementPartial); - static_assert(kBlockK % kGmemElemsPerLoad == 0, "kBlockK must be a multiple of kGmemElemsPerLoad"); - static constexpr int kBlockKGmem = kBlockK % 128 == 0 ? 128 : (kBlockK % 64 == 0 ? 64 : 32); - static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; - static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow"); - using GmemCopyAtom = std::conditional_t< - Has_cp_async, - cute::Copy_Atom, ElementPartial>, - cute::Copy_Atom, ElementPartial> - >; - using GmemLayoutAtom = Layout, Int>, - Stride, _1>>; - static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0); - using GmemTiledCopyAccum = decltype( - make_tiled_copy(GmemCopyAtom{}, - GmemLayoutAtom{}, - Layout>>{})); // Val layout, 4 vals per load - using GmemTiledCopy = decltype( - make_tiled_copy(Copy_Atom, Element>{}, - GmemLayoutAtom{}, - Layout>>{})); // Val layout, 4 vals per load - - using AlignmentTypeLSE = cute::uint_byte_t(sizeof(float)) * AlignmentLSE>; - static constexpr int kGmemElemsPerLoadLSE = sizeof(AlignmentTypeLSE) / sizeof(float); - static_assert(kBlockM % kGmemElemsPerLoadLSE == 0, "kBlockM must be a multiple of kGmemElemsPerLoadLSE"); - static_assert(kBlockM % 8 == 0, "kBlockM must be a multiple of 8"); - static constexpr int kBlockMSmem = kBlockM % 128 == 0 ? 128 : (kBlockM % 64 == 0 ? 64 : (kBlockM % 32 == 0 ? 32 : (kBlockM % 16 == 0 ? 16 : 8))); - static constexpr int kGmemThreadsPerRowLSE = kBlockMSmem / kGmemElemsPerLoadLSE; - static_assert(MaxThreadsPerBlock % kGmemThreadsPerRowLSE == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRowLSE"); - using GmemLayoutAtomLSE = Layout, Int>, - Stride, _1>>; - static_assert(kMaxSplits % CUTE_STATIC_V(shape<0>(GmemLayoutAtomLSE{})) == 0); - using GmemCopyAtomLSE = std::conditional_t< - Has_cp_async, - cute::Copy_Atom, float>, - cute::Copy_Atom, float> - >; - using GmemTiledCopyLSE = decltype( - make_tiled_copy(GmemCopyAtomLSE{}, - GmemLayoutAtomLSE{}, - Layout>>{})); // Val layout, 4 vals per load - - // Otherwise we get IMA when some threads access sLSE, as we're not doing any masking - static_assert((kBlockM * kMaxSplits * AlignmentLSE) % kNThreads == 0, "kNThreads must divide kBlockM * kMaxSplits * AlignmentLSE"); - // This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts - using SmemLSESwizzle = std::conditional_t< - kBlockMSmem == 8, - Swizzle<5, 0, 5>, - std::conditional_t, Swizzle<3, 2, 3>> - >; - using SmemLayoutAtomLSE = - decltype(composition(SmemLSESwizzle{}, - Layout, Int>, - Stride, _1>>{})); - using SmemLayoutLSE = decltype(tile_to_shape(SmemLayoutAtomLSE{}, Shape, Int>{})); - - using SmemLayoutO = Layout, Int, Int>, - Stride, _1, Int>>; - - // We want each column (kMaxSplits) to be processed by threads in the same warp. - // To reduce the number of shuffles, we want as few threads on the same column as possible. - // E.g., if kBlockM is divisible by 64, and there are 256 threads, we want 4 threads (0, 1, 2, 4) per column - // have have 64 such quads. - static_assert(MaxThreadsPerBlock % kBlockMSmem == 0, "MaxThreadsPerBlock must be a multiple of kBlockMSmem"); - static constexpr int kSmemThreadsPerColLSEt = MaxThreadsPerBlock / kBlockMSmem; - static_assert(cutlass::NumThreadsPerWarp % kSmemThreadsPerColLSEt == 0, "kSmemThreadsPerColLSEt must divide NumThreadsPerWarp"); - using S2RLayoutAtomLSE = Layout, Int>>; - using S2RTiledCopyLSE = decltype(make_tiled_copy(cute::Copy_Atom{}, S2RLayoutAtomLSE{}, Layout<_1>{})); - - using ShapeOPartial = cute::Shape; // (seqlen, d, num_splits, head, batch) - using StrideOPartial = cute::Stride; - using ShapeLSEPartial = cute::Shape; // (seqlen, num_splits, head, batch) - using StrideLSEPartial = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen, num_splits, head, batch) - using ShapeO = cute::Shape; // (seqlen, d, head, batch) - using StrideO = cute::Stride; - using ShapeLSE = cute::Shape; // (seqlen, head, batch) - using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch) - - struct BlockCoord { - int block_m; - int block_k; - int bidb; - }; - - struct SharedStorage : cute::aligned_struct<128> { - cute::array_aligned> smem_lse_partial; - cute::array_aligned smem_max_valid_split; - cute::array_aligned> smem_o_partial; - BlockCoord block_coord; - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - - // Device side arguments - struct Arguments { - int b; - ElementPartial const* const ptr_O_partial; - ShapeOPartial const shape_O_partial; - StrideOPartial const stride_O_partial; - float const* const ptr_LSE_partial; - ShapeLSEPartial const shape_LSE_partial; - StrideLSEPartial const stride_LSE_partial; - Element* const ptr_O; - StrideO const stride_O; - float* const ptr_LSE; - StrideLSE const stride_LSE; - int const* const cu_seqlens = nullptr; - int const* const seqused = nullptr; - int const* const num_splits_dynamic_ptr = nullptr; - int* const semaphore_to_reset = nullptr; - }; - - // Kernel entry point API - struct CollectiveParams { - int b; - ElementPartial const* const ptr_O_partial; - ShapeOPartial const shape_O_partial; - StrideOPartial const stride_O_partial; - float const* const ptr_LSE_partial; - ShapeLSEPartial const shape_LSE_partial; - StrideLSEPartial const stride_LSE_partial; - Element* const ptr_O; - StrideO const stride_O; - float* const ptr_LSE; - StrideLSE const stride_LSE; - cutlass::FastDivmod seqlen_divmod, head_divmod; - int const* const cu_seqlens = nullptr; - int const* const seqused = nullptr; - int const* const num_splits_dynamic_ptr = nullptr; - int* const semaphore_to_reset = nullptr; - }; - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static - CollectiveParams - to_underlying_arguments(Arguments const& args) { - assert(get<1>(args.shape_LSE_partial) <= kMaxSplits); - return { - args.b, - args.ptr_O_partial, - args.shape_O_partial, - args.stride_O_partial, - args.ptr_LSE_partial, - args.shape_LSE_partial, - args.stride_LSE_partial, - args.ptr_O, - args.stride_O, - args.ptr_LSE, - args.stride_LSE, - cutlass::FastDivmod(get<0>(args.shape_LSE_partial)), cutlass::FastDivmod(get<2>(args.shape_LSE_partial)), - args.cu_seqlens, - args.seqused, - args.num_splits_dynamic_ptr, - args.semaphore_to_reset - }; - } - - struct SchedulerArguments { - int b; - int seqlen_q; - int total_q; - int num_heads; - int dv; - int const* cu_seqlens_q; - int const* seqused_q; - }; - - struct StaticTileScheduler { - struct Params {}; - static Params to_underlying_arguments(SchedulerArguments const& args) { return {}; } - - SharedStorage& shared_storage; - CUTE_DEVICE StaticTileScheduler(SharedStorage& shared_storage): shared_storage(shared_storage) {} - - static dim3 get_grid_shape(SchedulerArguments const& args) { - unsigned int num_blocks_k = cute::ceil_div(args.dv, kBlockK); - unsigned int num_blocks_m = cute::ceil_div(args.seqlen_q * args.num_heads, kBlockM); - return {num_blocks_m, num_blocks_k, static_cast(args.b)}; - } - - CUTE_DEVICE BlockCoord get_block_coord(Params const& params) { - int block_m = blockIdx.x; - int block_k = blockIdx.y; - int bidb = blockIdx.z; - return {block_m, block_k, bidb}; - } - }; - - struct StaticVarlenTileScheduler { - // - // For varlen we have two Scheduling algos: - // 1) STANDARD, same as StaticTileScheduler - // 2) LINEARIZE_M_AND_BATCH, this flattens the tiled M dimension and - // batch dimension into a linear tile index. The grid is then a - // 2D grid of (tile_id, k_block). We then map the linear tile id - // to (m_block, bidb) in the get_block_coord function. This mapping - // is non-trivial since each batch element can have a different - // number of m_blocks. This has overhead when computing the block - // coordinates, but it is more efficient when prefills and decodes - // are mixed since in that case the STANDARD scheduling algo will - // have a lot of empty (no work) blocks in the grid. - // - - enum SchedulingAlgo { - STANDARD, // Same as StaticTileScheduler - LINEARIZE_M_AND_BATCH, // Linearize the M and batch dimensions into a single tile index - }; - - struct Params { - int b; - int num_heads; - int const* const cu_seqlens_q; - int const* const seqused_q; - SchedulingAlgo algo; - }; - - SharedStorage& shared_storage; - CUTE_DEVICE StaticVarlenTileScheduler(SharedStorage& shared_storage): shared_storage(shared_storage) {} - - static SchedulingAlgo choose_scheduling_algo(SchedulerArguments const& args) { - // Choose the scheduling algorithm based on how dense the grid of tiles that - // do actual work is. If the grid is more then 50% sparse, we linearize the M - // and batch. If the grid is more than 50% dense, we use the standard scheduling - // algorithm since its more efficient at calculating the block coordinates. - // NOTE: in varlen case args.seqlen_q is the max seqlen_q across all batches - // use lower bound to estimate when the density is more than 50% - int lower_bound_on_non_empty_tiles = cute::ceil_div(args.total_q, kBlockM); - int grid_size = args.b * cute::ceil_div(args.seqlen_q, kBlockM); - return 2 * lower_bound_on_non_empty_tiles >= grid_size ? - SchedulingAlgo::STANDARD : - SchedulingAlgo::LINEARIZE_M_AND_BATCH; - } - - static Params to_underlying_arguments(SchedulerArguments const& args) { - return { - args.b, - args.num_heads, - args.cu_seqlens_q, - args.seqused_q, - choose_scheduling_algo(args) - }; - } - - static dim3 get_grid_shape(SchedulerArguments const& args) { - unsigned int num_blocks_k = cute::ceil_div(args.dv, kBlockK); - - switch (choose_scheduling_algo(args)) { - case SchedulingAlgo::STANDARD: { - unsigned int num_blocks_k = cute::ceil_div(args.dv, kBlockK); - unsigned int num_blocks_m = cute::ceil_div(args.seqlen_q * args.num_heads, kBlockM); - return {num_blocks_m, num_blocks_k, static_cast(args.b)}; - } - case SchedulingAlgo::LINEARIZE_M_AND_BATCH: { - // rough worst case upper bound on the number of blocks required - // (assuming each batch has an additional partial block) - unsigned int num_blocks_m = cute::ceil_div(args.total_q * args.num_heads, kBlockM) + args.b; - return {num_blocks_m, num_blocks_k, 1}; - }} - - // rough worst case upper bound on the number of blocks required - // (assuming each batch has an additional partial block) - unsigned int num_blocks_m = cute::ceil_div(args.total_q * args.num_heads, kBlockM) + args.b; - return {num_blocks_m, num_blocks_k, 1}; - } - - CUTE_DEVICE BlockCoord get_block_coord_linearized_m_and_batch(Params const& params) { - int num_heads = params.num_heads; - int curr_tile_id = blockIdx.x; - - // Scan through the batches find the batch that contains the current - // tile_id. Compute using only the first warp of the block. - if (threadIdx.x < 32) { - // We compute linearized tile index start and ends for each batch - // in groups of 32 in parallel - int group_start_bidb = -(cutlass::NumThreadsPerWarp); - int group_end_bidb = 0; - int group_end_tile_id = 0; - int group_start_tile_id = 0; - int group_total_num_tiles = 0; - - int local_num_m_blocks = 0; - int local_num_m_blocks_cumulative = 0; - - do { - group_start_bidb += cutlass::NumThreadsPerWarp; - group_end_bidb += cutlass::NumThreadsPerWarp; - - auto get_num_m_blocks = [&](int bidb) { - if (bidb >= params.b) return 0; - flash::SeqlenInfo seqlen_info{bidb, 0, params.cu_seqlens_q, params.seqused_q}; - return cute::ceil_div(seqlen_info.seqlen * num_heads, Int{}()); - }; - - // Cumulative number of blocks for the next 31 batches - local_num_m_blocks = get_num_m_blocks(group_start_bidb + threadIdx.x); - local_num_m_blocks_cumulative = warp_prefix_sum(local_num_m_blocks); - // Total number of blocks for the next 32 batches - group_total_num_tiles = warp_shfl_get_last(local_num_m_blocks_cumulative); - - group_start_tile_id = group_end_tile_id; - group_end_tile_id += group_total_num_tiles; - } while (curr_tile_id >= group_end_tile_id && group_end_bidb < params.b); - - int local_batch_end_tile_id = group_start_tile_id + local_num_m_blocks_cumulative; - // Find the last batch idx in the group where `local_batch_end_tile_id <= curr_tile_id` - // these values below are now common to all threads in the warp - int batch_idx_in_group = warp_last_true_laneid(local_batch_end_tile_id <= curr_tile_id); - int batch_num_m_blocks = warp_shfl_get(local_num_m_blocks, batch_idx_in_group); - int batch_m_start_tile_id = group_start_tile_id + (batch_idx_in_group > 0 ? - warp_shfl_get(local_num_m_blocks_cumulative, batch_idx_in_group - 1) : 0); - - int bidb = group_start_bidb + batch_idx_in_group; - int block_m = curr_tile_id - batch_m_start_tile_id; - // NOTE(lucas): not sure why this causes a block_k unused warning - // just inlined `blockIdx.y` to suppress the warning - // int block_k = blockIdx.y; - // shared_storage.block_coord = {block_m, block_k, bidb}; - BlockCoord block_coord{block_m, static_cast(blockIdx.y), bidb}; - if (threadIdx.x == 0) { shared_storage.block_coord = block_coord; } - } - - __syncthreads(); - return shared_storage.block_coord; - } - - - CUTE_DEVICE BlockCoord get_block_coord_standard(Params const& params) { - int block_m = blockIdx.x; - int block_k = blockIdx.y; - int bidb = blockIdx.z; - return {block_m, block_k, bidb}; - } - - CUTE_DEVICE BlockCoord get_block_coord(Params const& params) { - switch (params.algo) { - case SchedulingAlgo::STANDARD: - return get_block_coord_standard(params); - case SchedulingAlgo::LINEARIZE_M_AND_BATCH: - return get_block_coord_linearized_m_and_batch(params); - } - return {0, 0, 0}; // Should never reach here - } - }; - - using TileScheduler = std::conditional_t< - Varlen, - StaticVarlenTileScheduler, - StaticTileScheduler - >; - - using SchedulerParams = typename TileScheduler::Params; - - struct Params { - CollectiveParams params; - SchedulerParams scheduler_params; - }; - - CUTLASS_DEVICE - void - operator()(Params const& kernel_params, char* smem_buf) { - CollectiveParams const& params = kernel_params.params; - - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - TileScheduler tile_scheduler{shared_storage}; - - Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse_partial.data()), SmemLayoutLSE{}); - Tensor sMaxValidSplit = make_tensor(make_smem_ptr(shared_storage.smem_max_valid_split.data()), Shape>{}); - Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{}); - - int const thread_idx = threadIdx.x; - - BlockCoord block_coord = tile_scheduler.get_block_coord(kernel_params.scheduler_params); - - int const m_block = block_coord.block_m; - int const k_block = block_coord.block_k; - int const batch = block_coord.bidb; - - if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) { - cutlass::arch::wait_on_dependent_grids(); - *params.semaphore_to_reset = 0; - } - - flash::SeqlenInfo seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused}; - int const offset = seqlen_info.offset; - int const seqlen = seqlen_info.seqlen; - int max_idx = seqlen * get<2>(params.shape_LSE_partial); - - bool block_coord_valid = - block_coord.block_m < cute::ceil_div(max_idx, Int{}) && - block_coord.bidb < params.b; - if (!block_coord_valid) { return; } - - int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial); - if (num_splits <= 1) { return; } - - cutlass::FastDivmod seqlen_divmod_dynamic(seqlen); - - // Step 1: load LSE_partial from gmem -> smem - Tensor mLSEpartial = make_tensor(make_gmem_ptr(params.ptr_LSE_partial + offset * get<0>(params.stride_LSE_partial)), - select<1, 0, 2, 3>(params.shape_LSE_partial), - select<1, 0, 2, 3>(params.stride_LSE_partial))(_, _, _, !Varlen ? batch : 0); // (num_splits, seqlen, head) - Tensor mLSEpartial_copy = cute::tiled_divide(mLSEpartial, Shape<_1, Int>{}); - GmemTiledCopyLSE gmem_tiled_copy_LSE; - auto gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_thread_slice(thread_idx); - Tensor tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE); - - // Construct identity layout for sLSE - Tensor cLSE = make_identity_tensor(make_shape(size<0>(sLSE), size<1>(sLSE))); // (NUM_SPLITS, BLK_M) -> (num_splits, blk_m) - // Repeat the partitioning with identity layouts - Tensor tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE); - - cutlass::arch::wait_on_dependent_grids(); - - #pragma unroll - for (int m = 0; m < size<2>(tLSEcLSE); ++m) { - int mi = int(get<1>(tLSEcLSE(_0{}, _0{}, m))); - int idx = m_block * kBlockM + mi; - if (idx < max_idx) { - int m_idx, bidh; - if constexpr (!Varlen) { - bidh = params.seqlen_divmod.divmod(m_idx, idx); - } else { - bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); - } - Tensor mLSEpartial_cur_copy = mLSEpartial_copy(_, _, m_idx, bidh); - #pragma unroll - for (int s = 0; s < size<1>(tLSEcLSE); ++s) { - int si = get<0>(tLSEcLSE(_0{}, s, _0{})); - // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast(&(tLSEsLSE(_0{}, s, m))), reinterpret_cast(&(tLSEsLSE(_0{}, s, m))) / 4 % 32);} - if (si < num_splits) { - cute::copy(gmem_tiled_copy_LSE, mLSEpartial_cur_copy(_, si), tLSEsLSE(_, s, m)); - } else { - cute::fill(tLSEsLSE(_, s, m), -INFINITY); - } - } - } else { - // We don't need to zero out the rest of the LSEs, as we will not write the output to gmem - // cute::fill(tLSEsLSE(_, _, m), -INFINITY); - } - } - if constexpr (Has_cp_async) { cute::cp_async_fence(); } - - // Step 2: Load O_partial from gmem -> smem for split = 0, 1, ..., kStages - 2. - // We want these async loads to be in flight as we compute the LSE. - GmemTiledCopyAccum gmem_tiled_copy_O_partial; - auto gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_thread_slice(thread_idx); - // Construct identity layout for gO - Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_O_partial.partition_D(cO); - Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)), - params.shape_O_partial, params.stride_O_partial)(_, _, _, _, !Varlen ? batch : 0); // (seqlen, d, num_splits, head) - - // Precompute these values to avoid recomputing them in the loop - Tensor tOmidx = make_tensor(make_shape(size<1>(tOcO))); - Tensor tObidh = make_tensor(make_shape(size<1>(tOcO))); - Tensor tOrOptr = make_tensor(make_shape(size<1>(tOcO))); - #pragma unroll - for (int m = 0; m < size<1>(tOcO); ++m) { - int mi = get<0>(tOcO(_0{}, m, _0{})); - int idx = m_block * kBlockM + mi; - if constexpr (!Varlen) { - tObidh(m) = params.seqlen_divmod.divmod(tOmidx(m), idx); - } else { - tObidh[m] = seqlen_divmod_dynamic.divmod(tOmidx(m), idx); - } - tOrOptr[m] = &mOpartial(tOmidx(m), k_block * kBlockK, _0{}, tObidh(m)); - if (idx >= max_idx) { - tObidh[m] = -1; - } - } - - Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); - if constexpr (!(Is_even_K)) { - #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial) - k_block * kBlockK; } - } - - Tensor tOsOpartial = gmem_thr_copy_O_partial.partition_D(sO); - - auto load_O_partial = [&] (int split, int stage) { - Tensor tOsOpartial_cur = tOsOpartial(_, _, _, stage); - #pragma unroll - for (int m = 0; m < size<1>(tOcO); ++m) { - if (tObidh(m) >= 0) { - Tensor mOpartial_cur = make_tensor(make_gmem_ptr(tOrOptr[m]), mOpartial(_0{}, _, _, _0{}).layout()); - Tensor mOpartial_cur_copy = cute::tiled_divide(mOpartial_cur, Shape>{}); - #pragma unroll - for (int k = 0; k < size<2>(tOcO); ++k) { - int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad; - if (Is_even_K || tOpO(k)) { - cute::copy(gmem_tiled_copy_O_partial, mOpartial_cur_copy(_, k_idx, split), tOsOpartial_cur(_, m, k)); - } - } - } - } - }; - - for (int s = 0; s < kStages - 1; ++s) { - if (s < num_splits) { load_O_partial(s, s); } - if constexpr (Has_cp_async) { cute::cp_async_fence(); } - } - - // Step 3: load and transpose LSE_partial from smem -> rmem - if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait(); } - __syncthreads(); - - S2RTiledCopyLSE s2r_tiled_copy_LSE; - auto s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_thread_slice(thread_idx); - Tensor ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE); - Tensor ts2rrLSE = make_fragment_like(ts2rsLSE); - cute::copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE); - - // Step 4: compute the final LSE along the split dimension - Tensor lse_sum = make_tensor(make_shape(size<2>(ts2rrLSE))); - Tensor ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE); - // We compute the max valid split for each row to short-circuit the computation later - Tensor max_valid_split = make_tensor(make_shape(size<2>(ts2rrLSE))); - static_assert(CUTE_STATIC_V(size<0>(ts2rrLSE)) == 1); - #pragma unroll - for (int m = 0; m < size<2>(ts2rrLSE); ++m) { - float lse_max = ts2rrLSE(_0{}, _0{}, m); - #pragma unroll - for (int s = 1; s < size<1>(ts2rrLSE); ++s) { lse_max = max(lse_max, ts2rrLSE(_0{}, s, m)); } - MaxOp max_op; - lse_max = Allreduce::run(lse_max, max_op); - int max_valid_idx = -1; - #pragma unroll - for (int s = 0; s < size<1>(ts2rrLSE); ++s) { - if (ts2rrLSE(_0{}, s, m) != -INFINITY) { max_valid_idx = get<0>(ts2rcLSE(_0{}, s, _0{})); } - } - MaxOp max_int_op; - max_valid_split[m] = Allreduce::run(max_valid_idx, max_int_op); - float lse_max_cur = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf - float lse_sum_cur = 0.f; - #pragma unroll - for (int s = 0; s < size<1>(ts2rrLSE); ++s) { - float scale = expf(ts2rrLSE(_0{}, s, m) - lse_max_cur); - lse_sum_cur += scale; - // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast(&(ts2rsLSE(_0{}, s, m))), reinterpret_cast(&(ts2rsLSE(_0{}, s, m))) / 4 % 32);} - // ts2rsLSE(_0{}, m, s) = scale; - ts2rrLSE(_0{}, s, m) = scale; - } - SumOp sum_op; - lse_sum_cur = Allreduce::run(lse_sum_cur, sum_op); - lse_sum(m) = logf(lse_sum_cur) + lse_max; - float inv_sum = (lse_sum_cur == 0.f || lse_sum_cur != lse_sum_cur) ? 0.f : 1.f / lse_sum_cur; - #pragma unroll - for (int s = 0; s < size<1>(ts2rrLSE); ++s) { ts2rrLSE(_0{}, s, m) *= inv_sum; } - } - // Store the scales exp(lse - lse_logsum) back to smem - cute::copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE); - - // Store max_valid_split to smem - #pragma unroll - for (int m = 0; m < size<2>(ts2rrLSE); ++m) { - if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to smem - int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m))); - if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; } - } - } - - // Step 5: store final LSE back to gmem - if (k_block == 0) { - auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial); - Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE)(_, _, !Varlen ? batch : 0); - #pragma unroll - for (int m = 0; m < size<2>(ts2rrLSE); ++m) { - if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to gmem - int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m))); - int idx = m_block * kBlockM + mi; - if (idx < max_idx) { - int m_idx, bidh; - if constexpr (!Varlen) { - bidh = params.seqlen_divmod.divmod(m_idx, idx); - } else { - bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); - } - // printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m)); - mLSE(m_idx, bidh) = lse_sum(m); - } - } - } - } - - // Step 6: read O_partial from gmem -> smem -> rmem and accumulate the final O - __syncthreads(); - int thr_max_valid_split = sMaxValidSplit[get<0>(tOcO(_0{}, _0{}, _0{}))]; - #pragma unroll - for (int m = 1; m < size<1>(tOcO); ++m) { thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[get<0>(tOcO(_0{}, m, _0{}))]); } - Layout tOrOpartial_layout = gmem_thr_copy_O_partial.partition_S(make_tensor(TileShape_MK{})).layout(); - Tensor tOrOpartial = make_fragment_like(tOrOpartial_layout); - Tensor tOrO = make_fragment_like(tOrOpartial); - clear(tOrO); - int stage_load = kStages - 1, stage_compute = 0; - #pragma unroll 4 // Already tuned for speed - for (int s = 0; s <= thr_max_valid_split; ++s) { - Tensor scale = make_tensor(make_shape(size<1>(tOrOpartial))); - #pragma unroll - for (int m = 0; m < size<1>(tOrOpartial); ++m) { scale(m) = sLSE(s, get<0>(tOcO(_0{}, m, _0{}))); } - - if (s + kStages - 1 <= thr_max_valid_split) { load_O_partial(s + kStages - 1, stage_load); } - if constexpr (Has_cp_async) { cute::cp_async_fence(); } - stage_load = stage_load < kStages - 1 ? stage_load + 1 : 0; - if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait(); } - // We don't need __syncthreads() because each thread is just reading its own data from smem - cute::copy(Copy_Atom, ElementPartial>{}, - tOsOpartial(_, _, _, stage_compute), tOrOpartial); - stage_compute = stage_compute < kStages - 1 ? stage_compute + 1 : 0; - - #pragma unroll - for (int m = 0; m < size<1>(tOrOpartial); ++m) { - if (tObidh(m) >= 0 && scale(m) > 0.f) { - #pragma unroll - for (int k = 0; k < size<2>(tOrOpartial); ++k) { - if (Is_even_K || tOpO(k)) { - Tensor rOpartial = make_tensor_like(tOrOpartial(_, m, k)); - flash::convert_type_out(tOrOpartial(_, m, k), rOpartial); - #pragma unroll - for (int i = 0; i < size<0>(tOrOpartial); ++i) { - tOrO(i, m, k) += scale(m) * rOpartial[i]; - } - } - } - } - } - } - - // Step 7: Write the final O to gmem - Tensor rO = make_tensor_like(tOrO); - flash::convert_type_out(tOrO, rO); - auto shape_O = make_shape(get<0>(params.shape_O_partial), get<1>(params.shape_O_partial) - k_block * kBlockK, get<3>(params.shape_O_partial), get<4>(params.shape_O_partial)); - Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O) + k_block * kBlockK * get<1>(params.stride_O)), - shape_O, params.stride_O)(_, _, _, !Varlen ? batch : 0); - Tensor mO_copy = cute::tiled_divide(mO, Shape<_1, Int>{}); - GmemTiledCopy gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); - - #pragma unroll - for (int m = 0; m < size<1>(tOcO); ++m) { - if (tObidh(m) >= 0) { - #pragma unroll - for (int k = 0; k < size<2>(tOcO); ++k) { - int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad; - if (Is_even_K || tOpO(k)) { - cute::copy(gmem_tiled_copy_O, rO(_, m, k), mO_copy(_, tOmidx(m), k_idx, tObidh(m))); - } - } - } - } - - } - -}; - -} // namespace flash diff --git a/flash-attn/flash_fwd_combine_launch_template.h b/flash-attn/flash_fwd_combine_launch_template.h deleted file mode 100644 index c99efdadfebf87ae7d70ca987fd90fea8184c299..0000000000000000000000000000000000000000 --- a/flash-attn/flash_fwd_combine_launch_template.h +++ /dev/null @@ -1,88 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include "cute/tensor.hpp" - -#include "cutlass/cutlass.h" -#include "cutlass/arch/arch.h" // For cutlass::arch::Sm80 -#include "cutlass/device_kernel.h" // For device_kernel -#include "cutlass/kernel_launch.h" // For kernel_launch - -#include "static_switch.h" -#include "flash.h" -#include "flash_fwd_combine_kernel.h" - -using namespace cute; - -template -void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl) { - using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; - using TileShape_MK = cute::Shape, Int>; - using CombineKernel = flash::FlashAttnFwdCombine; - - typename CombineKernel::Arguments args { - params.b, - static_cast(params.oaccum_ptr), - {!Varlen ? params.seqlen_q : params.total_q, params.dv, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_O_partial - {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0}, // stride_O_partial - static_cast(params.softmax_lseaccum_ptr), - {!Varlen ? params.seqlen_q : params.total_q, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_LSE_partial - {_1{}, params.lseaccum_split_stride, params.lseaccum_head_stride, !Varlen ? params.lseaccum_batch_stride : 0}, // stride_LSE_partial - static_cast(params.o_ptr), - {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0}, // stride_O - static_cast(params.softmax_lse_ptr), - {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0}, // stride_LSE - params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore - }; - - typename CombineKernel::SchedulerArguments scheduler_args { - params.b, params.seqlen_q, params.total_q, params.h, params.dv, - params.cu_seqlens_q, params.seqused_q - }; - - typename CombineKernel::Params kernel_params = { - CombineKernel::to_underlying_arguments(args), - CombineKernel::TileScheduler::to_underlying_arguments(scheduler_args) - }; - - dim3 grid_m = CombineKernel::TileScheduler::get_grid_shape(scheduler_args); - auto kernel = cutlass::device_kernel; - int smem_size = CombineKernel::SharedStorageSize; - if (smem_size >= 48 * 1024) { - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - // kernel<<>>(kernel_params); - cutlass::kernel_launch(grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream, kernel_params, Arch >= 90 && enable_pdl /*launch_with_pdl*/); - CHECK_CUDA_KERNEL_LAUNCH(); -} - -template -void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl) { - // We want kBlockM to be as small as possible to maximize parallelism. - // E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats). - static_assert(kBlockK % 32 == 0, "kBlockK must be a multiple of 32"); - static constexpr int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32); - ARCH_SWITCH(params.arch, Arch, [&] { - BOOL_SWITCH(params.cu_seqlens_q || params.seqused_q, Varlen, [&] { - if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32. - if (params.num_splits <= 16) { - run_flash_fwd_combine(params, stream, enable_pdl); - return; - } - } - if (params.num_splits <= 32) { - run_flash_fwd_combine(params, stream, enable_pdl); - } else if (params.num_splits <= 64) { - run_flash_fwd_combine(params, stream, enable_pdl); - } else if (params.num_splits <= 128) { - run_flash_fwd_combine(params, stream, enable_pdl); - } else { - run_flash_fwd_combine(params, stream, enable_pdl); - } - }); - }); -} diff --git a/flash-attn/flash_fwd_kernel_sm80.h b/flash-attn/flash_fwd_kernel_sm80.h deleted file mode 100644 index b308d2d1b8892397db52434df9da4d4ab6cf5c56..0000000000000000000000000000000000000000 --- a/flash-attn/flash_fwd_kernel_sm80.h +++ /dev/null @@ -1,215 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include "cute/tensor.hpp" - -#include -#include -#include -#include - -#include "seqlen.h" -#include "utils.h" -#include "softmax.h" - -namespace flash { - -using namespace cute; - -template -class FlashAttnFwdSm80 { - -public: - - // Type Aliases - using CollectiveMainloop = CollectiveMainloop_; - using CollectiveEpilogue = CollectiveEpilogue_; - static constexpr bool Is_causal = CollectiveMainloop::Is_causal; - static constexpr bool Is_local = CollectiveMainloop::Is_local; - static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen); - static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap; - static constexpr bool Varlen = CollectiveMainloop::Varlen; - static constexpr bool PagedKV = CollectiveMainloop::PagedKV; - static constexpr bool Split = CollectiveMainloop::Split; - static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8; - static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V; - static constexpr bool AppendKV = CollectiveMainloop::AppendKV; - static constexpr bool PackGQA = CollectiveMainloop::PackGQA; - static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads; - using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t; - - // Mainloop derived types - using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK; - using TiledMma = typename CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - - // Epilogue derived types - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - - static_assert(ArchTag::kMinComputeCapability >= 80); - - using TileScheduler = TileScheduler_; - using TileSchedulerArguments = typename flash::TileSchedulerArguments; - using TileSchedulerParams = typename TileScheduler::Params; - - static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMma{})); - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})); - static constexpr uint32_t MinBlocksPerMultiprocessor = NumThreads == 128 ? 2 : 1; - - // Kernel level shared memory storage - // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v + smem_k and not smem_q - // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v) + sizeof(smem_k). - static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) - - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v))) - - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k))); - static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_; - struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128> { - union { - struct { - cute::array padding_; - typename CollectiveMainloop::TensorStorage mainloop; - }; - // We want smem_o to line up with the start of smem_v - typename CollectiveEpilogue::TensorStorage epilogue; - }; - } tensors; - - alignas(16) typename TileScheduler::SharedStorage smem_scheduler; - - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - - // Device side arguments - struct Arguments { - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - cutlass::KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel entry point API - struct Params { - MainloopParams mainloop{}; - EpilogueParams epilogue{}; - cutlass::KernelHardwareInfo hw_info{}; - TileSchedulerParams scheduler{}; - }; - - // - // Methods - // - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static - Params - to_underlying_arguments(Arguments const& args) { - CUTLASS_TRACE_HOST("to_underlying_arguments():"); - - // Get SM count if needed, otherwise use user supplied SM count - int sm_count = args.hw_info.sm_count; - if (sm_count <= 0) { - CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); - sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); - } - - CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); - - cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; - return { - CollectiveMainloop::to_underlying_arguments(args.mainloop), - CollectiveEpilogue::to_underlying_arguments(args.epilogue), - hw_info, - TileScheduler::to_underlying_arguments(args.scheduler) - }; - } - - // Computes the kernel launch grid shape based on runtime parameters - static dim3 - get_grid_shape(Params const& params) { - return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count * MinBlocksPerMultiprocessor); - } - - static dim3 - get_block_shape() { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { - - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - CollectiveMainloop mainloop; - CollectiveEpilogue epilogue; - - TileScheduler scheduler(reinterpret_cast(&shared_storage.smem_scheduler)); - // Initialize matmul objects. - TiledMma tiled_mma; - - scheduler.init_consumer(); - - int warp_idx = cutlass::canonical_warp_idx_sync(); - CUTLASS_PRAGMA_NO_UNROLL - for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work(params.scheduler) : scheduler.template get_initial_work(params.scheduler); - work_tile_info.is_valid(params.scheduler); - work_tile_info = warp_idx == 0 ? scheduler.template get_next_work(params.scheduler, work_tile_info) : scheduler.template get_next_work(params.scheduler, work_tile_info)) { - // Attention output (GEMM-II) accumulator. - Tensor tOrO = partition_fragment_C(tiled_mma, select<0, 2>(TileShape_MNK{})); - float softmax_scale_log2 = params.mainloop.softmax_scale_log2; - // If there's tanh softcap, the scaling will be done before tanh. - auto block_coord = work_tile_info.get_block_coord(params.scheduler); - int const bidb = get<2>(block_coord); - if constexpr (Is_FP8 && !Has_softcap) { - int const bidh = get<1>(block_coord); - int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh; - float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)]; - float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)]; - softmax_scale_log2 *= q_descale * k_descale; - } - flash::Softmax<2 * (2 * kBlockM / NumThreads), /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2); - - SeqlenInfo_t seqlen_info{ - bidb, - get<0>(params.mainloop.shape_Q), - !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), - get<0>(params.mainloop.shape_K_new), - params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, - params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, - params.mainloop.seqlens_rotary - }; - if constexpr (AppendKV) { - bool tile_new_valid = mainloop.store_kv_new( - params.mainloop, threadIdx.x, shared_storage, seqlen_info, block_coord); - if (tile_new_valid) { __syncthreads(); } - } - bool tile_valid = mainloop.mma( - params.mainloop, tOrO, softmax, threadIdx.x, seqlen_info, block_coord, - shared_storage); - scheduler.prefetch_next_work(params.scheduler, work_tile_info); - if (tile_valid) { - // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); } - epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma, - threadIdx.x, block_coord); - } else { - // Write 0 to gO and -inf to gLSE. - epilogue.store_zero(params.epilogue, threadIdx.x, block_coord); - } - } - - } - -}; - -} // namespace flash diff --git a/flash-attn/flash_fwd_kernel_sm90.h b/flash-attn/flash_fwd_kernel_sm90.h deleted file mode 100644 index 242da9bf8a48d50642c0cb7e25cfc2ef65a4ec55..0000000000000000000000000000000000000000 --- a/flash-attn/flash_fwd_kernel_sm90.h +++ /dev/null @@ -1,468 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include "cute/tensor.hpp" - -#include -#include -#include -#include -#include -#include -#include "cutlass/pipeline/pipeline.hpp" - -#include "cutlass/arch/grid_dependency_control.h" - -#include "seqlen.h" -#include "utils.h" -#include "softmax.h" - -namespace flash { - -using namespace cute; - -template -class FlashAttnFwdSm90 { - -public: - - // Type Aliases - using CollectiveMainloop = CollectiveMainloop_; - using CollectiveEpilogue = CollectiveEpilogue_; - static constexpr bool Is_causal = CollectiveMainloop::Is_causal; - static constexpr bool Is_local = CollectiveMainloop::Is_local; - static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen); - static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap; - static constexpr bool Varlen = CollectiveMainloop::Varlen; - static constexpr bool Split = CollectiveMainloop::Split; - static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8; - static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V; - static constexpr bool AppendKV = CollectiveMainloop::AppendKV; - static constexpr bool HasQv = CollectiveMainloop::HasQv; - static constexpr bool Use_TMA_Q = CollectiveMainloop::Use_TMA_Q; - static constexpr bool Use_TMA_KV = CollectiveMainloop::Use_TMA_KV; - static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O; - static constexpr bool PackGQA = CollectiveMainloop::PackGQA; - static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads; - static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim; - static constexpr bool LargeHeadDimV = CollectiveMainloop::LargeHeadDimV; - static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV); - using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t; - - using SmemLayoutSAux = typename CollectiveMainloop::SmemLayoutSAux; - - // Mainloop derived types - using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV; - using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ClusterShape = typename CollectiveMainloop::ClusterShape; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - using BarrierQ = std::conditional_t; - - // Epilogue derived types - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - - static_assert(ArchTag::kMinComputeCapability >= 90); - - using TileScheduler = TileScheduler_; - using TileSchedulerArguments = typename flash::TileSchedulerArguments; - using TileSchedulerParams = typename TileScheduler::Params; - - static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaPV{})) / cutlass::NumThreadsPerWarpGroup; - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaPV{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); - - /// Register requirement for Load and Math WGs - // If we use cp.async to load K and V, we need more registers for the producer WG. - static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 24 : 40) : 32); - static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 240 : 232) : 160); - // If you want to print from the producer warp, you'd need to increase the number of registers - // Otherwise you'll get CUDA error. - // static constexpr uint32_t LoadRegisterRequirement = 40; - // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152; - - // Kernel level shared memory storage - // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v - // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v). - static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v))); - static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_; - struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128, _1> { - union { - struct { - cute::array padding_; - typename CollectiveMainloop::TensorStorage mainloop; - }; - // We want smem_o to line up with the start of smem_v - typename CollectiveEpilogue::TensorStorage epilogue; - }; - } tensors; - struct PipelineStorage : cute::aligned_struct<16, _1> { - alignas(16) BarrierQ barrier_Q; - alignas(16) BarrierQ barrier_Qv; - alignas(16) cutlass::arch::ClusterBarrier barrier_O; - alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k; - alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v; - alignas(16) typename CollectiveMainloop::MainloopPipelineVt::SharedStorage pipeline_vt; - alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_k_new; - alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_v_new; - alignas(16) typename TileScheduler::SharedStorage smem_scheduler; - } pipelines; - - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - - // Device side arguments - struct Arguments { - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - cutlass::KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel entry point API - struct Params { - MainloopParams mainloop{}; - EpilogueParams epilogue{}; - cutlass::KernelHardwareInfo hw_info{}; - TileSchedulerParams scheduler{}; - }; - - // - // Methods - // - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static - Params - to_underlying_arguments(Arguments const& args) { - CUTLASS_TRACE_HOST("to_underlying_arguments():"); - - // Get SM count if needed, otherwise use user supplied SM count - int sm_count = args.hw_info.sm_count; - if (sm_count <= 0) { - CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); - sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); - } - - CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); - - cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; - return { - CollectiveMainloop::to_underlying_arguments(args.mainloop), - CollectiveEpilogue::to_underlying_arguments(args.epilogue), - hw_info, - TileScheduler::to_underlying_arguments(args.scheduler) - }; - } - - // Computes the kernel launch grid shape based on runtime parameters - static dim3 - get_grid_shape(Params const& params) { - return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count); - } - - static dim3 - get_block_shape() { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { - - static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; - static constexpr int MmaThreadOffset = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup; - static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); - - using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK; - using MainloopPipelineV = typename CollectiveMainloop::MainloopPipelineV; - using MainloopPipelineVt = typename CollectiveMainloop::MainloopPipelineVt; - using MainloopPipelineKVNew = typename CollectiveMainloop::MainloopPipelineKVNew; - using PipelineState = typename CollectiveMainloop::PipelineState; - using PipelineParamsK = typename MainloopPipelineK::Params; - using PipelineParamsV = typename MainloopPipelineV::Params; - using PipelineParamsVt = typename MainloopPipelineVt::Params; - using PipelineParamsKVNew = typename MainloopPipelineKVNew::Params; - - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - int const lane_predicate = cute::elect_one_sync(); - int const warp_idx = cutlass::canonical_warp_idx_sync(); - - // Issue Tma Descriptor Prefetch from a single thread - if (warp_idx == 0 && lane_predicate) { - CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); - CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); - } - - // Obtain warp index - int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; - int warp_group_idx = cutlass::canonical_warp_group_idx(); - - if (warp_idx == 0 && lane_predicate) { - shared_storage.pipelines.barrier_Q.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/); - if constexpr (HasQv) { - shared_storage.pipelines.barrier_Qv.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/); - } - shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) * (Use_TMA_O ? 1 : NumMmaThreads) /*numThreads*/); - } - - // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init(); - PipelineParamsK pipeline_params_k; - pipeline_params_k.role = warp_group_idx == 0 - ? MainloopPipelineK::ThreadCategory::Producer - : MainloopPipelineK::ThreadCategory::Consumer; - if constexpr (Use_TMA_KV) { - pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; - pipeline_params_k.is_leader = warp_group_thread_idx == 0; - pipeline_params_k.num_consumers = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup; - } else { - pipeline_params_k.consumer_arv_count = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup; - pipeline_params_k.producer_arv_count = NumProducerThreads; - } - - static_assert(is_same_v); - PipelineParamsVt pipeline_params_vt = pipeline_params_k; - if constexpr (Use_TMA_KV && !SameHeadDim) { - pipeline_params_vt.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; - if constexpr (LargeHeadDimV) { pipeline_params_vt.num_consumers = NumMmaThreads; } - } else { - if constexpr (LargeHeadDimV) { pipeline_params_vt.consumer_arv_count = NumMmaThreads; } - } - - MainloopPipelineK pipeline_k = [&] { - if constexpr (Use_TMA_KV) { - return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{}); - } else { - return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k); - } - }(); - // MainloopPipelineV pipeline_v(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{}); - MainloopPipelineV pipeline_v = [&] { - if constexpr (!Transpose_V) { - static_assert(is_same_v); - if constexpr (Use_TMA_KV) { - return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt, ClusterShape{}); - } else { - return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt); - } - } else { - PipelineParamsV pipeline_params_v; - pipeline_params_v.role = warp_group_idx == 0 - ? MainloopPipelineV::ThreadCategory::Producer - : MainloopPipelineV::ThreadCategory::Consumer; - pipeline_params_v.producer_arv_count = NumProducerThreads; - pipeline_params_v.consumer_arv_count = NumMmaThreads; - return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v); - } - }(); - // If we need to transpose V (e.g. FP8 and V is row-major), we use pipeline_vt for the TMA, then - // the producer WG will read from pipeline_vt and write to pipeline_v. - // If we don't need to transpose V, we use pipeline_v for the TMA, and pipeline_vt won't be used. - // Technically for pipeline_params_vt, warp0 of WG0 is the producer and all of WG0 are consumers. - // However, the thread role isn't used in the pipeline implementation. - MainloopPipelineVt pipeline_vt = [&] { - if constexpr (Use_TMA_KV) { - pipeline_params_vt.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG - return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt, ClusterShape{}); - } else { - pipeline_params_vt.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG - return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt); - } - }(); - - PipelineParamsKVNew pipeline_params_kv_new; - pipeline_params_kv_new.role = warp_group_idx == 0 - ? MainloopPipelineKVNew::ThreadCategory::Producer - : MainloopPipelineKVNew::ThreadCategory::Consumer; - pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; - pipeline_params_kv_new.is_leader = warp_group_thread_idx == 0; - pipeline_params_kv_new.num_consumers = NumMmaThreads; - auto pipeline_k_new = cute::conditional_return(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_k_new, pipeline_params_kv_new, ClusterShape{}), nullptr); - if constexpr (!SameHeadDim) { - pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; - } - auto pipeline_v_new = cute::conditional_return(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_v_new, pipeline_params_kv_new, ClusterShape{}), nullptr); - - CollectiveMainloop mainloop; - CollectiveEpilogue epilogue; - - const int num_heads = get<2>(params.mainloop.shape_Q); - Tensor gS_aux = make_tensor(make_gmem_ptr(params.mainloop.ptr_S_aux), make_shape(num_heads)); - Tensor sS_aux = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_s_aux.data()), SmemLayoutSAux{}); - - if(params.mainloop.ptr_S_aux && threadIdx.x < num_heads) { - sS_aux(threadIdx.x) = gS_aux(threadIdx.x); - } - - // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster - if constexpr (size(ClusterShape{}) > 1) { - cute::cluster_arrive_relaxed(); - cute::cluster_wait(); - } else { - __syncthreads(); - } - - TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); - - if (warp_group_idx == 0) { // Producer - cutlass::arch::warpgroup_reg_dealloc(); - - // The pipelines for AppendKV and main attention are different, since e.g. main attention - // might use cp.async to load KV (if PagedKVNonTMA) while AppendKV always uses TMA to load - // KV_new. Since the pipeline states are different, we have to manually sync to make - // sure the two pipelines don't race when accessing smem_k and smem_v. - PipelineState smem_pipe_write = cutlass::make_producer_start_state(); - PipelineState smem_pipe_write_new = cutlass::make_producer_start_state(); - int work_idx = 0; - int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); - static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp; - if constexpr (SingleProducerWarp) { - if (warp_idx_in_warpgroup != 0) { return; } - } - if (!SingleProducerWarp && warp_idx_in_warpgroup != 0) { scheduler.init_consumer(); } - - cutlass::arch::wait_on_dependent_grids(); - - // Load Q, K, V - for (auto work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_initial_work(params.scheduler) : scheduler.template get_initial_work(params.scheduler); - work_tile_info.is_valid(params.scheduler); - work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_next_work(params.scheduler, work_tile_info) : scheduler.template get_next_work(params.scheduler, work_tile_info)) { - - auto block_coord = work_tile_info.get_block_coord(params.scheduler); - SeqlenInfo_t seqlen_info{ - get<2>(block_coord) /*bidb*/, - get<0>(params.mainloop.shape_Q), - !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), - get<0>(params.mainloop.shape_K_new), - params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, - params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, - params.mainloop.seqlens_rotary - }; - if constexpr (AppendKV) { - bool tile_new_valid = mainloop.load_kv_new( - params.mainloop, pipeline_k_new, pipeline_v_new, - smem_pipe_write_new, shared_storage, seqlen_info, block_coord, work_idx); - if (tile_new_valid) { - // if (threadIdx.x == 0) { printf("Producer: Before sync\n"); } - cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast(FwdNamedBarriers::AppendKV) /*id*/); - // if (threadIdx.x == 0) { printf("Producer: After sync\n"); } - } - } - auto scheduler_prefetch = [&scheduler, ¶ms, &work_tile_info]() { - scheduler.prefetch_next_work(params.scheduler, work_tile_info); - }; - // pipeline_vt won't be used if we don't need to transpose V. - mainloop.load(params.mainloop, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, - shared_storage, scheduler_prefetch, seqlen_info, block_coord, work_idx); - } - mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, work_idx); - } else { // Consumer - cutlass::arch::warpgroup_reg_alloc(); - - // Initialize matmul objects. - TiledMmaPV tiled_mma_pv; - - PipelineState smem_pipe_read; - PipelineState smem_pipe_read_new; - // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v - // (like in Cutlass's gemm) because the read and release pipeline states are always the same. - - scheduler.init_consumer(); - mainloop.mma_init(); - - int work_idx = 0; - CUTLASS_PRAGMA_NO_UNROLL - for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); - work_tile_info.is_valid(params.scheduler); - // get_next_work will be called before the epilogue - ) { - auto block_coord = work_tile_info.get_block_coord(params.scheduler); - int const bidb = get<2>(block_coord); - SeqlenInfo_t seqlen_info{ - bidb, - get<0>(params.mainloop.shape_Q), - !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), - get<0>(params.mainloop.shape_K_new), - params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, - params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, - params.mainloop.seqlens_rotary - }; - if constexpr (AppendKV) { - bool tile_new_valid = mainloop.store_kv_new( - params.mainloop, pipeline_k_new, pipeline_v_new, smem_pipe_read_new, - threadIdx.x - MmaThreadOffset, shared_storage, seqlen_info, block_coord); - if (tile_new_valid) { - // if (threadIdx.x == 128) { printf("Consumer: Before sync\n"); } - // We need this sync so that the gmem write from the consumers is visible to the producer - // that might do TMA read after that. - asm volatile ("fence.proxy.async.global;"); - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast(FwdNamedBarriers::AppendKV) /*id*/); - // arrive is enough, we don't need sync. The producer will sync, which means - // after that sync we're guaranteed that the AppendKV pipeline have finished - // loading and consumer smem_k and smem_v. - // if (threadIdx.x == 128) { printf("Consumer: After sync\n"); } - } - } - // If there's tanh softcap, the scaling will be done before tanh. - float softmax_scale_log2 = params.mainloop.softmax_scale_log2; - if constexpr (Is_FP8 && !Has_softcap) { - int const bidh = get<1>(block_coord); - int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh; - float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)]; - float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)]; - softmax_scale_log2 *= q_descale * k_descale; - } - flash::Softmax softmax(softmax_scale_log2); - // Attention output (GEMM-II) accumulator. - Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{})); - bool tile_valid; - if constexpr (!LargeHeadDimV) { - tile_valid = mainloop.mma( - params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, - tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); - } else { // mma_pv might not compile if !LargeHeadDimV - if (warp_group_idx == 1) { - tile_valid = mainloop.mma( - params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, - tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); - } else { - tile_valid = mainloop.mma_pv( - params.mainloop, pipeline_v, smem_pipe_read, - tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage); - } - } - // Do this here before the epilogue so that the next tile is ready to go. - work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info); - if constexpr (Split && Varlen) { - if (!work_tile_info.is_valid(params.scheduler)) { // Last tile - cutlass::arch::launch_dependent_grids(); - } - } - if (tile_valid) { - // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); } - epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv, - threadIdx.x - MmaThreadOffset, block_coord); - } else { - // Write 0 to gO and -inf to gLSE. - epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord); - } - } - epilogue.store_tail(); - } - - } - -}; - -} // namespace flash diff --git a/flash-attn/flash_fwd_launch_template.h b/flash-attn/flash_fwd_launch_template.h deleted file mode 100644 index 2c9363300a594547b8f4aa3d99c4f0933d9e0784..0000000000000000000000000000000000000000 --- a/flash-attn/flash_fwd_launch_template.h +++ /dev/null @@ -1,231 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include "cute/tensor.hpp" - -#include "cutlass/cutlass.h" -#include "cutlass/device_kernel.h" // For device_kernel -#include -#include "cutlass/cluster_launch.hpp" -#include "cutlass/kernel_launch.h" - -#include "static_switch.h" -#include "flash.h" -#include "tile_size.h" -#include "tile_scheduler.hpp" -#include "flash_fwd_kernel_sm90.h" -#include "flash_fwd_kernel_sm80.h" -#include "mainloop_fwd_sm90_tma_gmma_ws.hpp" -#include "mainloop_fwd_sm80.hpp" -#include "epilogue_fwd.hpp" -#include "heuristics.h" - -using namespace cute; - -template -void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time"); - static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time"); - static_assert(!(AppendKV && !Varlen), "AppendKV requires Varlen"); - static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v; - static constexpr bool FP8_TransposeV = Is_FP8 && !V_colmajor; - using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; - using ElementS = cutlass::bfloat16_t; - - // Can't use structured binding since it's not compatible with constexpr - static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg); - static constexpr std::tuple kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV); - static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS); - static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS); - static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap); - static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap); - static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS); - static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS); - static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS); - - using TileShape_MNK = cute::Shape, Int, Int>; - using TileShape_MNK_PV = cute::Shape, Int, Int>; - using ClusterShape = cute::Shape, _1, _1>; - using CollectiveMainloop = std::conditional_t< - Arch >= 90, - flash::CollectiveMainloopFwdSm90, - flash::CollectiveMainloopFwdSm80 - >; - using CollectiveEpilogue = flash::CollectiveEpilogueFwd; - - static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; - using SchedulerPersistent = std::conditional_t= 90 /*WarpSpecialized*/>, - std::conditional_t, - flash::DynamicPersistentTileScheduler= 90 /*WarpSpecialized*/> - > - >; - using SchedulerSingleTile = flash::SingleTileScheduler; - // If Split then we probably don't have enough work for PersistentScheduler to be useful. - // However, if Varlen (e.g., during decode where we have max_seqlens), using PersistentScheduler is better - // since we'll avoid launching a bunch of thread blocks that immediately exit. - // On Sm80, noncausal persistent seems a bit slower. - static constexpr bool UsePersistentScheduler = Arch >= 90 ? !(Split && !Varlen) : ((Is_causal && !Varlen) || (Varlen && Split)); - using Scheduler = std::conditional_t; - using AttnKernel = std::conditional_t< - Arch >= 90, - flash::enable_sm90_or_later>, - flash::enable_sm80_to_sm89> - >; - - bool const is_varlen_q = params.cu_seqlens_q; - bool const is_varlen_k = params.cu_seqlens_k; - bool const is_varlen_k_new = params.cu_seqlens_knew; - int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q; - int batch_q = !is_varlen_q ? params.b : 1; - int batch_k = !is_varlen_k ? (params.kv_batch_idx ? params.b_k : params.b) : 1; - typename CollectiveMainloop::StrideV v_strides = - cute::conditional_return( - make_stride(params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0), - make_stride(_1{}, params.v_dim_stride, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0)); - typename CollectiveMainloop::Arguments mainloop_args { - static_cast(params.q_ptr), - {seqlen_q, params.d, params.h, batch_q}, // shape_Q - {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q - static_cast(params.k_ptr), - {!params.page_table ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size, - params.d, params.h_k, !params.page_table ? batch_k : params.num_pages}, // shape_K - {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K - static_cast(params.v_ptr), - params.dv, // headdim_v - v_strides, // stride_V - static_cast(params.knew_ptr), - {!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1}, // shape_K_new - {params.knew_row_stride, _1{}, params.knew_head_stride, !is_varlen_k_new ? params.knew_batch_stride : 0}, // stride_K_new - static_cast(params.vnew_ptr), - {params.vnew_row_stride, _1{}, params.vnew_head_stride, !is_varlen_k_new ? params.vnew_batch_stride : 0}, // stride_V_new - static_cast(params.qv_ptr), - {params.qv_row_stride, _1{}, params.qv_head_stride, !is_varlen_q ? params.qv_batch_stride : 0}, // stride_Qv - static_cast(params.rotary_cos_ptr), - {params.seqlen_k, params.rotary_dim / 2}, // shape_rotary, the seqlen shape doesn't matter - {params.rotary_dim / 2, _1{}}, // stride_rotary_cos - static_cast(params.rotary_sin_ptr), - {params.rotary_dim / 2, _1{}}, // stride_rotary_sin - params.is_rotary_interleaved, - params.page_table, - // if page_size is not set, avoid dividing by zero - {params.kv_batch_idx ? params.b_k : params.b, !params.page_table ? 0 : params.seqlen_k / params.page_size}, // shape_page_table - {params.page_table_batch_stride, _1{}}, // stride_page_table - params.scale_softmax, - params.q_descale_ptr, params.k_descale_ptr, params.v_descale_ptr, - {params.q_descale_batch_stride, params.q_descale_head_stride}, - {params.k_descale_batch_stride, params.k_descale_head_stride}, - {params.v_descale_batch_stride, params.v_descale_head_stride}, - params.window_size_left, params.window_size_right, - params.softcap, - params.num_splits, - params.kv_batch_idx, - params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, - params.seqused_q, params.seqused_k, - params.leftpad_k, params.seqlens_rotary, - static_cast(params.s_aux_ptr) - }; - typename CollectiveEpilogue::Arguments epilogue_args { - static_cast(params.o_ptr), - {seqlen_q, params.dv, params.h, batch_q, params.num_splits}, // shape_O - {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0, 0}, // stride_O - static_cast(params.oaccum_ptr), - {params.oaccum_row_stride, _1{}, params.oaccum_head_stride, !is_varlen_q ? params.oaccum_batch_stride : 0, params.oaccum_split_stride}, // stride_O_partial - static_cast(params.softmax_lse_ptr), - {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, 0}, // stride_LSE - static_cast(params.softmax_lseaccum_ptr), - {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, params.h * seqlen_q * batch_q}, // stride_LSE_partial - params.h_k, - params.cu_seqlens_q, params.seqused_q - }; - - int qhead_per_khead = !PackGQA ? 1 : cutlass::ceil_div(params.h, params.h_k); - int num_blocks_m = cutlass::ceil_div(params.seqlen_q * qhead_per_khead, get<0>(TileShape_MNK{})); - num_blocks_m = cutlass::round_up(num_blocks_m, size<0>(ClusterShape{})); - typename flash::TileSchedulerArguments scheduler_args { - num_blocks_m, !PackGQA ? params.h : params.h_k, params.b, params.num_splits, - params.h / params.h_k, - params.seqlen_q, - params.seqlen_k, params.d, params.dv, sizeof(Element), - params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q, - // params.num_m_blocks_ptr, - params.num_splits_dynamic_ptr, - }; - - if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) { - prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/); - CHECK_CUDA_KERNEL_LAUNCH(); - } - - int device; - CHECK_CUDA(cudaGetDevice(&device)); - typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({ - mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args - }); - - dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params); - dim3 block_dims = AttnKernel::get_block_shape(); - int smem_size = AttnKernel::SharedStorageSize; - // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q)); - // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k)); - // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)); - // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v); - // Get the ptr to kernel function. - if constexpr (size(ClusterShape{}) > 1) { - void const* kernel = (void const*) cutlass::device_kernel; - if (smem_size >= 48 * 1024) { - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); - cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream}; - cutlass::launch_kernel_on_cluster(launch_params, kernel, kernel_params); - } else { - auto kernel = cutlass::device_kernel; - if (smem_size >= 48 * 1024) { - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - // kernel<<>>(kernel_params); - cutlass::kernel_launch(grid_dims, block_dims, smem_size, stream, kernel_params, - Arch >= 90 && Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation /*launch_with_pdl*/); - } - CHECK_CUDA_KERNEL_LAUNCH(); -} - -template -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported"); - static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v; - using T_out = std::conditional_t; - CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] { - VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] { - static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1; - VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { - BOOL_SWITCH(use_one_mma_wg(params), Use_one_mma_wg_, [&] { - // Avoid over compiliation by making sure this only get set if it is actually used, i.e. we currently only support one mma wg for 128 head dim and hopper - static constexpr bool Use_one_mma_wg = Use_one_mma_wg_ && Arch >= 90 && kHeadDim == 128; - - // Only needed here to decide if we should use cluster - static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg)) : 128; - - static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; - BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { - static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256; - APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { - // Only use Cluster if number of tiles along seqlen_q is even and not varlen - CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { - static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; - run_flash_fwd(params, stream); - }); - }); - }); - }); - }); - }); - }); -} diff --git a/flash-attn/flash_prepare_scheduler.cu b/flash-attn/flash_prepare_scheduler.cu deleted file mode 100644 index 7093fff32b673c06bd1dd7b09bcdf3c4deddaa53..0000000000000000000000000000000000000000 --- a/flash-attn/flash_prepare_scheduler.cu +++ /dev/null @@ -1,124 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#include "cutlass/fast_math.h" -#include "cutlass/barrier.h" -#include "cutlass/arch/barrier.h" - -#include "cutlass/arch/grid_dependency_control.h" - -#include "flash.h" - -namespace flash { - -__global__ void prepare_varlen_num_blocks_kernel( - int seqlen_q_static, int seqlen_k_static, int seqlen_k_new_static, - int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new, - int const* const seqused_q, int const* const seqused_k, int const* const leftpad_k_ptr, - int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static, - cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod, - int* const tile_count_semaphore, - // int* const num_m_blocks_ptr, - int* const num_splits_dynamic_ptr, - bool enable_pdl) { - - static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1; - static constexpr int kSmemSize = 1; - // Assume that there's only one block in the grid - __shared__ int total_blocks_smem[kSmemSize]; - - // There's only 1 block in the grid, so might as well start launching the main attn kernel - if (enable_pdl) { cutlass::arch::launch_dependent_grids(); } - - if (threadIdx.x < kSmemSize) { total_blocks_smem[threadIdx.x] = 0; } - __syncthreads(); - - if (threadIdx.x == 0 && tile_count_semaphore) { *tile_count_semaphore = 0; } - - int lane = threadIdx.x % cutlass::NumThreadsPerWarp; - - auto get_num_m_blocks = [&](int bidb_start) { - int batch_idx = lane + bidb_start; - int seqlen; - if (seqused_q) { - seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0; - } else if (cu_seqlens_q) { - int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_q[batch_idx] : 0; - int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); - seqlen = next_cu_seqlen - cur_cu_seqlen; - } else { - seqlen = seqlen_q_static; - } - seqlen *= qhead_per_khead; - return batch_idx < num_batch && lane < kNumBatchPerWarp - ? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0; - }; - - auto get_num_n_blocks = [&](int bidb_start) { - int batch_idx = lane + bidb_start; - int leftpad_k = batch_idx < num_batch && leftpad_k_ptr != nullptr ? leftpad_k_ptr[batch_idx] : 0; - int seqlen; - if (seqused_k) { - seqlen = batch_idx < num_batch ? seqused_k[batch_idx] : 0; - } else if (cu_seqlens_k) { - int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_k[batch_idx] : 0; - int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); - seqlen = next_cu_seqlen - cur_cu_seqlen; - } else { - seqlen = seqlen_k_static; - } - int seqlen_new; - if (cu_seqlens_k_new) { - int cur_cu_seqlen_new = batch_idx <= num_batch ? cu_seqlens_k_new[batch_idx] : 0; - int next_cu_seqlen_new = __shfl_down_sync(0xffffffff, cur_cu_seqlen_new, 1); - seqlen_new = next_cu_seqlen_new - cur_cu_seqlen_new; - } else { - seqlen_new = seqlen_k_new_static; - } - // if (threadIdx.x == 0) { printf("seqlen = %d, seqlen_new = %d, leftpad_k = %d\n", seqlen, seqlen_new, leftpad_k); } - seqlen = seqlen - leftpad_k + seqlen_new; - return batch_idx < num_batch && lane < kNumBatchPerWarp - ? blockn_divmod.div(seqlen + blockn_divmod.divisor - 1) : 0; - }; - - int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp; - int bidb_start = kNumBatchPerWarp * warp_idx; - int num_m_blocks = get_num_m_blocks(bidb_start); - int num_n_blocks = get_num_n_blocks(bidb_start); - - int total_blocks = num_m_blocks * num_n_blocks; - // Warp sum - #pragma unroll - for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) { - total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i); - } - if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); } - __syncthreads(); - total_blocks = total_blocks_smem[0]; - // 10% margin - int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm))); - // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM - int num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); - if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) { - num_splits_dynamic_ptr[bidb_start + lane] = num_splits_dynamic; - // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); - } -} - -} // flash - -void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, - int blockM, int blockN, bool enable_pdl) { - // Only support batch <= 992 (32 warps, each with 31 batches) - int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k); - flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 1024 /*block*/, 0, stream>>>( - params.seqlen_q, params.seqlen_k, params.seqlen_knew, - params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, - params.seqused_q, params.seqused_k, params.leftpad_k, - params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits, - cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN), - params.tile_count_semaphore, - // params.num_m_blocks_ptr, - params.num_splits_dynamic_ptr, enable_pdl); -} diff --git a/flash-attn/heuristics.h b/flash-attn/heuristics.h deleted file mode 100644 index 43d06f54825934d05c801854371e91d2ce843ae1..0000000000000000000000000000000000000000 --- a/flash-attn/heuristics.h +++ /dev/null @@ -1,65 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include -#include "flash.h" - -inline bool use_one_mma_wg(Flash_fwd_params const& params) { - return params.arch >= 90 && params.d == 128 && - params.seqlen_q * (!params.pack_gqa ? 1 : params.h / params.h_k) <= 64; -}; - -inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, int blockM) { - // If varlen, we don't actually know seqlen_q but only max_seqlen_q. - if (varlen_q) return true; - // Heuristic: PackGQA is a bit slower but can help if seqlen_q is small or not near a multiple of kBlockM - auto round_up = [](int a, int b) { return (a + b - 1) / b * b; }; - float nopack_gqa_efficiency = float(seqlen_q) / float(round_up(seqlen_q, blockM)); - float pack_gqa_efficiency = float(seqlen_q * qhead_per_khead) / float(round_up(seqlen_q * qhead_per_khead, blockM)); - return nopack_gqa_efficiency < 0.9 * pack_gqa_efficiency; -}; - -// Find the number of splits that maximizes the occupancy. For example, if we have -// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is -// better than having 3 splits (efficiency = 0.67). However, we also don't want too many -// splits as that would incur more HBM reads/writes. -// So we find the best efficiency, then find the smallest number of splits that gets 85% -// of the best efficiency. -inline int num_splits_heuristic(int total_mblocks, int num_SMs, int num_n_blocks, int num_m_blocks, int size_one_kv_head, bool is_causal_or_local, int max_splits) { - // If we have enough to almost fill the SMs, then just use 1 split - // However, in the case of super long seqlen where each head of KV doesn't even fit into - // L2 (we assume that L2 size is 50MB), we want to split. - if (total_mblocks >= 0.8f * num_SMs) { - int const size_l2 = 50 * 1024 * 1024; - // Only split if there are enough queries to go over the KV at least twice - // Don't split if causal - if (size_one_kv_head > size_l2 && num_m_blocks >= num_SMs * 2 && !is_causal_or_local) { - return std::min((size_one_kv_head + size_l2 - 1) / size_l2, max_splits); - } else { - return 1; - } - } - // If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512. - if (num_n_blocks <= 4) { return 1; } - max_splits = std::min({max_splits, num_SMs, num_n_blocks}); - float max_efficiency = 0.f; - std::vector efficiency; - efficiency.reserve(max_splits); - for (int num_splits = 1; num_splits <= max_splits; num_splits++) { - float n_waves = float(total_mblocks * num_splits) / num_SMs; - float eff = n_waves / ceil(n_waves); - // printf("num_splits = %d, eff = %f\n", num_splits, eff); - if (eff > max_efficiency) { max_efficiency = eff; } - efficiency.push_back(eff); - } - for (int num_splits = 1; num_splits <= max_splits; num_splits++) { - if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { - // printf("num_splits chosen = %d\n", num_splits); - return num_splits; - } - } - return 1; -} diff --git a/flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu deleted file mode 100644 index 8cf5ed22c2ab6f2b32ffa89fffce0f7227d91901..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template<> -void run_mha_bwd_<80, cutlass::bfloat16_t, 128, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128<80, cutlass::bfloat16_t, false>(params, stream); -} -template<> -void run_mha_bwd_<86, cutlass::bfloat16_t, 128, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128<86, cutlass::bfloat16_t, false>(params, stream); -} -#endif -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu deleted file mode 100644 index 54fc817c9ca4a9cfb9bde98593f3c4bf33a6618f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template<> -void run_mha_bwd_<90, cutlass::bfloat16_t, 128, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128<90, cutlass::bfloat16_t, false>(params, stream); -} -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu deleted file mode 100644 index 01b28afbc0cc82b159e60719fb8777d71f61e50b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template<> -void run_mha_bwd_<80, cutlass::bfloat16_t, 128, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128<80, cutlass::bfloat16_t, true>(params, stream); -} -template<> -void run_mha_bwd_<86, cutlass::bfloat16_t, 128, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128<86, cutlass::bfloat16_t, true>(params, stream); -} -#endif -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu deleted file mode 100644 index 5f4b207c53e4cec0edf0d50166fe2fd8c1a5d120..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template<> -void run_mha_bwd_<90, cutlass::bfloat16_t, 128, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128<90, cutlass::bfloat16_t, true>(params, stream); -} -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu deleted file mode 100644 index 5a4d6478cb24b4ffebd931315fcf6d72475d10be..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_hdim128_bf16_sm90.cu" -#include "flash_bwd_hdim128_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu deleted file mode 100644 index 6833e03dc841c25c4c1dbfab4a0fce5a0f28b4a8..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template<> -void run_mha_bwd_<80, cutlass::half_t, 128, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128<80, cutlass::half_t, false>(params, stream); -} -template<> -void run_mha_bwd_<86, cutlass::half_t, 128, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128<86, cutlass::half_t, false>(params, stream); -} -#endif -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu deleted file mode 100644 index d02d370b5032615979ca80df1a272caad1e94800..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template<> -void run_mha_bwd_<90, cutlass::half_t, 128, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128<90, cutlass::half_t, false>(params, stream); -} -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu deleted file mode 100644 index 7981d11c2a267df33925c685952090e29f039ba3..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template<> -void run_mha_bwd_<80, cutlass::half_t, 128, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128<80, cutlass::half_t, true>(params, stream); -} -template<> -void run_mha_bwd_<86, cutlass::half_t, 128, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128<86, cutlass::half_t, true>(params, stream); -} -#endif -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu deleted file mode 100644 index b474ace825b092c7805b4c94d033cba00fbc819f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template<> -void run_mha_bwd_<90, cutlass::half_t, 128, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128<90, cutlass::half_t, true>(params, stream); -} -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu deleted file mode 100644 index cc52974e4c7f70384faecec0c21a8a7260bc0a2d..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_hdim128_fp16_sm90.cu" -#include "flash_bwd_hdim128_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu deleted file mode 100644 index 41b4b01086a1aa809e485d5efe094cabd7d884fa..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template<> -void run_mha_bwd_<80, cutlass::bfloat16_t, 192, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192<80, cutlass::bfloat16_t, false>(params, stream); -} -template<> -void run_mha_bwd_<86, cutlass::bfloat16_t, 192, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192<86, cutlass::bfloat16_t, false>(params, stream); -} -#endif -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu deleted file mode 100644 index 6e81ff7346a5acf9f16b786ae72b6e163eda0e79..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template<> -void run_mha_bwd_<90, cutlass::bfloat16_t, 192, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192<90, cutlass::bfloat16_t, false>(params, stream); -} -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu deleted file mode 100644 index 15265f25b26c03ce6e8fb3e76a949b55eaac18bd..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template<> -void run_mha_bwd_<80, cutlass::bfloat16_t, 192, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192<80, cutlass::bfloat16_t, true>(params, stream); -} -template<> -void run_mha_bwd_<86, cutlass::bfloat16_t, 192, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192<86, cutlass::bfloat16_t, true>(params, stream); -} -#endif -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu deleted file mode 100644 index 0a363983ec586a9e416590339dadec8c698cad62..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template<> -void run_mha_bwd_<90, cutlass::bfloat16_t, 192, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192<90, cutlass::bfloat16_t, true>(params, stream); -} -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu deleted file mode 100644 index 1504a28475d09029661aeaa73cb3e85f0b355074..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_hdim192_bf16_sm90.cu" -#include "flash_bwd_hdim192_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu deleted file mode 100644 index 77ab70a00c0fb33a4c5baf2404a7097fa98bbaa3..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template<> -void run_mha_bwd_<80, cutlass::half_t, 192, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192<80, cutlass::half_t, false>(params, stream); -} -template<> -void run_mha_bwd_<86, cutlass::half_t, 192, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192<86, cutlass::half_t, false>(params, stream); -} -#endif -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu deleted file mode 100644 index 67a99dfc5d37c1034b2c9df8a5ad64e4b2784d08..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template<> -void run_mha_bwd_<90, cutlass::half_t, 192, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192<90, cutlass::half_t, false>(params, stream); -} -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu deleted file mode 100644 index c1a72b0bf27b47f127eb8c110d46cb739d7ea78e..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template<> -void run_mha_bwd_<80, cutlass::half_t, 192, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192<80, cutlass::half_t, true>(params, stream); -} -template<> -void run_mha_bwd_<86, cutlass::half_t, 192, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192<86, cutlass::half_t, true>(params, stream); -} -#endif -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu deleted file mode 100644 index 75291357b89dfbb1989937a7e3c4fa9c850b0841..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template<> -void run_mha_bwd_<90, cutlass::half_t, 192, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192<90, cutlass::half_t, true>(params, stream); -} -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu deleted file mode 100644 index fdd34babb46f0e4bc5f0aa2c8ba5e07ea33b8b7d..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_hdim192_fp16_sm90.cu" -#include "flash_bwd_hdim192_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu deleted file mode 100644 index fa0bc55f9a6f995eb7b59a8698edf367eba6bfcd..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template<> -void run_mha_bwd_<80, cutlass::bfloat16_t, 256, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256<80, cutlass::bfloat16_t, false>(params, stream); -} -template<> -void run_mha_bwd_<86, cutlass::bfloat16_t, 256, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256<86, cutlass::bfloat16_t, false>(params, stream); -} -#endif -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu deleted file mode 100644 index 937b03c5051f28b5f1c6f6595bf29d331cd296fb..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template<> -void run_mha_bwd_<90, cutlass::bfloat16_t, 256, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256<90, cutlass::bfloat16_t, false>(params, stream); -} -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu deleted file mode 100644 index 43ce3b615548d68edb40a3df9b2b9c6ff8eb39ae..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template<> -void run_mha_bwd_<80, cutlass::bfloat16_t, 256, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256<80, cutlass::bfloat16_t, true>(params, stream); -} -template<> -void run_mha_bwd_<86, cutlass::bfloat16_t, 256, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256<86, cutlass::bfloat16_t, true>(params, stream); -} -#endif -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu deleted file mode 100644 index b789174b1513a3793a2cac3a4a4eb0bb61ae5d3f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template<> -void run_mha_bwd_<90, cutlass::bfloat16_t, 256, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256<90, cutlass::bfloat16_t, true>(params, stream); -} -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu deleted file mode 100644 index 7d68af3d2764b0b0e53c6915d7e47dc7e0a4e697..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_hdim256_bf16_sm90.cu" -#include "flash_bwd_hdim256_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu deleted file mode 100644 index c6aa2475abf7c05b548ca8a1ca40e203d8bf6aa3..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template<> -void run_mha_bwd_<80, cutlass::half_t, 256, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256<80, cutlass::half_t, false>(params, stream); -} -template<> -void run_mha_bwd_<86, cutlass::half_t, 256, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256<86, cutlass::half_t, false>(params, stream); -} -#endif -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim256_fp16_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim256_fp16_sm90.cu deleted file mode 100644 index 970f4d6252f679e841d105eccc88a0f1eb2dee87..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim256_fp16_sm90.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template<> -void run_mha_bwd_<90, cutlass::half_t, 256, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256<90, cutlass::half_t, false>(params, stream); -} -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm80.cu deleted file mode 100644 index ba90899624df01f8e1173baf28eecbee9bae9450..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm80.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template<> -void run_mha_bwd_<80, cutlass::half_t, 256, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256<80, cutlass::half_t, true>(params, stream); -} -template<> -void run_mha_bwd_<86, cutlass::half_t, 256, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256<86, cutlass::half_t, true>(params, stream); -} -#endif -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm90.cu deleted file mode 100644 index 1a16a4d428c76dcd1cb29d6fb948afa42f322b9f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm90.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template<> -void run_mha_bwd_<90, cutlass::half_t, 256, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256<90, cutlass::half_t, true>(params, stream); -} -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim256_fp16_softcapall_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim256_fp16_softcapall_sm90.cu deleted file mode 100644 index cd471362dfcc43984ef5b9e08c9b250388900784..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim256_fp16_softcapall_sm90.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_hdim256_fp16_sm90.cu" -#include "flash_bwd_hdim256_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_bwd_hdim64_bf16_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim64_bf16_sm80.cu deleted file mode 100644 index 7b242432c662e52defd583f4ab676231903063e2..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim64_bf16_sm80.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template<> -void run_mha_bwd_<80, cutlass::bfloat16_t, 64, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64<80, cutlass::bfloat16_t, false>(params, stream); -} -template<> -void run_mha_bwd_<86, cutlass::bfloat16_t, 64, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64<86, cutlass::bfloat16_t, false>(params, stream); -} -#endif -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim64_bf16_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim64_bf16_sm90.cu deleted file mode 100644 index 6a90f01b86a5e22f9bbef3f5ec7b5d692e1450df..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim64_bf16_sm90.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template<> -void run_mha_bwd_<90, cutlass::bfloat16_t, 64, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64<90, cutlass::bfloat16_t, false>(params, stream); -} -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm80.cu deleted file mode 100644 index fc6116b13e797248cd9d1d1eaf41da8f18181727..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm80.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template<> -void run_mha_bwd_<80, cutlass::bfloat16_t, 64, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64<80, cutlass::bfloat16_t, true>(params, stream); -} -template<> -void run_mha_bwd_<86, cutlass::bfloat16_t, 64, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64<86, cutlass::bfloat16_t, true>(params, stream); -} -#endif -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm90.cu deleted file mode 100644 index 481c1a7e895c6136a5061cfa3776019e642ae2bf..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm90.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template<> -void run_mha_bwd_<90, cutlass::bfloat16_t, 64, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64<90, cutlass::bfloat16_t, true>(params, stream); -} -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim64_bf16_softcapall_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim64_bf16_softcapall_sm90.cu deleted file mode 100644 index 414bca23c93bd1a392237ec5a9c33daa8a540044..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim64_bf16_softcapall_sm90.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_hdim64_bf16_sm90.cu" -#include "flash_bwd_hdim64_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_bwd_hdim64_fp16_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim64_fp16_sm80.cu deleted file mode 100644 index bd0205076c0e23b33b9a32f21294cccdf14b486d..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim64_fp16_sm80.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template<> -void run_mha_bwd_<80, cutlass::half_t, 64, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64<80, cutlass::half_t, false>(params, stream); -} -template<> -void run_mha_bwd_<86, cutlass::half_t, 64, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64<86, cutlass::half_t, false>(params, stream); -} -#endif -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim64_fp16_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim64_fp16_sm90.cu deleted file mode 100644 index 4a927c6322bbd6c06fd035c0225937263b723535..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim64_fp16_sm90.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template<> -void run_mha_bwd_<90, cutlass::half_t, 64, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64<90, cutlass::half_t, false>(params, stream); -} -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm80.cu deleted file mode 100644 index 401e817bdd38274fcb20c722ce208e64ca41ad36..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm80.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template<> -void run_mha_bwd_<80, cutlass::half_t, 64, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64<80, cutlass::half_t, true>(params, stream); -} -template<> -void run_mha_bwd_<86, cutlass::half_t, 64, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64<86, cutlass::half_t, true>(params, stream); -} -#endif -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm90.cu deleted file mode 100644 index d726b62c1fffa623570a1e124a6eb9eda58f2c5c..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm90.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template<> -void run_mha_bwd_<90, cutlass::half_t, 64, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64<90, cutlass::half_t, true>(params, stream); -} -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim64_fp16_softcapall_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim64_fp16_softcapall_sm90.cu deleted file mode 100644 index 10697f012d82724bfce5db1c36d420b30bf28cea..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim64_fp16_softcapall_sm90.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_hdim64_fp16_sm90.cu" -#include "flash_bwd_hdim64_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_bwd_hdim96_bf16_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim96_bf16_sm80.cu deleted file mode 100644 index 326a796ed257057d3e86b8e0f30e6852ae1961fa..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim96_bf16_sm80.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template<> -void run_mha_bwd_<80, cutlass::bfloat16_t, 96, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96<80, cutlass::bfloat16_t, false>(params, stream); -} -template<> -void run_mha_bwd_<86, cutlass::bfloat16_t, 96, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96<86, cutlass::bfloat16_t, false>(params, stream); -} -#endif -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim96_bf16_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim96_bf16_sm90.cu deleted file mode 100644 index f63535737a3f198f28673f33b108ebf67663f4ea..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim96_bf16_sm90.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template<> -void run_mha_bwd_<90, cutlass::bfloat16_t, 96, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96<90, cutlass::bfloat16_t, false>(params, stream); -} -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm80.cu deleted file mode 100644 index 63522b87f9163aa73530acddab50c2059e763b6b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm80.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template<> -void run_mha_bwd_<80, cutlass::bfloat16_t, 96, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96<80, cutlass::bfloat16_t, true>(params, stream); -} -template<> -void run_mha_bwd_<86, cutlass::bfloat16_t, 96, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96<86, cutlass::bfloat16_t, true>(params, stream); -} -#endif -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm90.cu deleted file mode 100644 index b16a8d376d259f3ecf158de3708b2d7704a79d06..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm90.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template<> -void run_mha_bwd_<90, cutlass::bfloat16_t, 96, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96<90, cutlass::bfloat16_t, true>(params, stream); -} -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim96_bf16_softcapall_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim96_bf16_softcapall_sm90.cu deleted file mode 100644 index f6593e1801a05ecbe5c79994a1155745d418df13..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim96_bf16_softcapall_sm90.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_hdim96_bf16_sm90.cu" -#include "flash_bwd_hdim96_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_bwd_hdim96_fp16_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim96_fp16_sm80.cu deleted file mode 100644 index 92ff4ad1ad9a72e941a9e2b8bdc6160eb28ecfe6..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim96_fp16_sm80.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template<> -void run_mha_bwd_<80, cutlass::half_t, 96, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96<80, cutlass::half_t, false>(params, stream); -} -template<> -void run_mha_bwd_<86, cutlass::half_t, 96, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96<86, cutlass::half_t, false>(params, stream); -} -#endif -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim96_fp16_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim96_fp16_sm90.cu deleted file mode 100644 index 56ab6760d68cea3655273af8dc5ba42a571ddf25..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim96_fp16_sm90.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template<> -void run_mha_bwd_<90, cutlass::half_t, 96, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96<90, cutlass::half_t, false>(params, stream); -} -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm80.cu deleted file mode 100644 index 0896abb5a6b399541555aa2589d1886c5f39a798..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm80.cu +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template<> -void run_mha_bwd_<80, cutlass::half_t, 96, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96<80, cutlass::half_t, true>(params, stream); -} -template<> -void run_mha_bwd_<86, cutlass::half_t, 96, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96<86, cutlass::half_t, true>(params, stream); -} -#endif -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm90.cu deleted file mode 100644 index 7d92efc0ce94642ca24f533ef6ff414be16a6d9b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm90.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template<> -void run_mha_bwd_<90, cutlass::half_t, 96, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96<90, cutlass::half_t, true>(params, stream); -} -#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim96_fp16_softcapall_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim96_fp16_softcapall_sm90.cu deleted file mode 100644 index 82265ca9edae8df97daa3ecefa5c211bec734a67..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_bwd_hdim96_fp16_softcapall_sm90.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_bwd_hdim96_fp16_sm90.cu" -#include "flash_bwd_hdim96_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu deleted file mode 100644 index affc7a4dd96b51efc4a3f5e452731a6d94ac33d0..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu deleted file mode 100644 index 7e13614bfeadbf03fe6f7b6d74d425c010f4bcd2..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu deleted file mode 100644 index 670041341bcf3f883cd9933801434faa52209228..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu deleted file mode 100644 index f315fbb45454d3d6ffb2f0471e0e8fa5ad27b575..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu deleted file mode 100644 index bde3024a4a68ee6ce74a98b48e450c2cdbd12078..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcapall_sm80.cu deleted file mode 100644 index 40afcfd68b17a4fd73e7053aae7d7d2b83a49c46..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim128_bf16_paged_sm80.cu" -#include "flash_fwd_hdim128_bf16_paged_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu deleted file mode 100644 index 2724463e621137c98ea37b303c3f70307fada8ae..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu deleted file mode 100644 index a38a1d5cf33d3839adff7452e5a8ff609cec2638..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu deleted file mode 100644 index 284eeba18235c2acf4b6eccce2f670ea19e27495..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu deleted file mode 100644 index 0c40ddba8fe6814103fefab86653c08941ce8003..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcapall_sm80.cu deleted file mode 100644 index 5bed85a0ddc4cfabf33abc1f9147f699ca5add97..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim128_bf16_paged_split_sm80.cu" -#include "flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_sm100.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_sm100.cu deleted file mode 100644 index 4fb8f71d01e8faf7648ac1988dd4f908870f25b5..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_sm100.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<100, cutlass::bfloat16_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_sm80.cu deleted file mode 100644 index cc89c4d5d2559acc5a53e9d163494920af7140ef..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_sm90.cu deleted file mode 100644 index 3a236b712c4d0c4660ee6113f5fe39d661537498..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu deleted file mode 100644 index 8449104c5aa188ab4416ab85ef9d3f7383605e7e..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu deleted file mode 100644 index b152b90bab7c9f50bce1cd5e3d6949182308a6d9..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu deleted file mode 100644 index 8cc4fed1739eb7b50b8479f7aa8ec608d5667816..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcapall_sm80.cu deleted file mode 100644 index eee51630db1e8c1d65ee5453c8584d8f25b8ac02..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim128_bf16_sm80.cu" -#include "flash_fwd_hdim128_bf16_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu deleted file mode 100644 index 1db3f1e6d803ed909f727a05ca81fc7e74c3b5af..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu deleted file mode 100644 index 9b3e294f1b3766502cd4c69cf2ddd89d3424b188..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu deleted file mode 100644 index 07bd687fc344c96299d80a6a991e1633ef03b0f8..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu deleted file mode 100644 index 5f44833b10da9db301a67de6247d4ff36b530e47..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcapall_sm80.cu deleted file mode 100644 index f51f96d92665a81732981ae15637a9b430a29c42..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim128_bf16_split_sm80.cu" -#include "flash_fwd_hdim128_bf16_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu deleted file mode 100644 index 9f95ca29f6be2cebf53b454ece8eea0bb418bc3a..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu deleted file mode 100644 index ad97737d4f3d2c02521338c2528d2670ced4a17d..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu deleted file mode 100644 index d77d37ec0414df8f00d8ca3ad4fcb4abc3db6e1f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu deleted file mode 100644 index ae05c7ce5f0a87b7081478571d466bd4ec5b7a7c..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu deleted file mode 100644 index bc52a9f356fd3640b7e4c55995026758d4508f2c..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_sm90.cu deleted file mode 100644 index 480d485d06908cd945e3b90038f471701478d6a0..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu deleted file mode 100644 index d3da5f4e6656c697f6da20e6dfb3d5e0bfa78d08..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu deleted file mode 100644 index 1c1c2d8207fd419136dbbc98429e8d0f32048f88..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu deleted file mode 100644 index 371d933e3e1e8cdd763644cb451aa036aa666497..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu deleted file mode 100644 index 7491148dcdebc67ed4f01ddcfa8ecebc86e97b9b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu deleted file mode 100644 index d04159a62a0aeb11465ca77be37057aaef1b78ac..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu deleted file mode 100644 index 28ad6c149637ab70a0f461b146758df6b119d6cd..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu deleted file mode 100644 index 7afb267e3ebf9e4b42b695c07a80aafacce23f43..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu deleted file mode 100644 index 69758584cb63da1305a294e555cff087e57bd716..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu deleted file mode 100644 index 3be45956bb4febcea1ec0d835da3c774664608fa..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcapall_sm80.cu deleted file mode 100644 index 29b97d0e149b8ee261a2b08d9e00dee1bf40d70c..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim128_fp16_paged_sm80.cu" -#include "flash_fwd_hdim128_fp16_paged_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu deleted file mode 100644 index 698095dad6a0a37a06e708b92fbb4bb62716e3cb..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu deleted file mode 100644 index 16d443a9ad170e6471f71caf44ea5f20f0b17f00..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu deleted file mode 100644 index 1e8f6af71bdabf35553d2e7c972f84e73c09e7b5..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu deleted file mode 100644 index 4ec688861124107eceb82634baa8698d19fa1578..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcapall_sm80.cu deleted file mode 100644 index b745a57a2042d74ab701635ce9e54897f353937d..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim128_fp16_paged_split_sm80.cu" -#include "flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_sm80.cu deleted file mode 100644 index 670b5952d9d1e4a65f0d3f7999befa19ce33d1c0..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_sm90.cu deleted file mode 100644 index b9778dc92e1b0c39197fa54ef7bf3de344e25505..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu deleted file mode 100644 index 446e917c79536e41113457299df62384a312faf6..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu deleted file mode 100644 index fd62a2c54352cdcf60dda75e0a7d847a9d2664dc..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu deleted file mode 100644 index 0a397f4acf2ab3a1c69aaf643b00d5ec9f65c1dd..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcapall_sm80.cu deleted file mode 100644 index 2292272ac2a45410117b72d2d1f9c8547e70ce4b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim128_fp16_sm80.cu" -#include "flash_fwd_hdim128_fp16_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu deleted file mode 100644 index 4d3c553e296a5cc9fbaec634a7916af400fa9130..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu deleted file mode 100644 index 77621846ffeb8ee66e4cb1a225c5ce85759ad905..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu deleted file mode 100644 index 7d217ac273384dc2fa989e908ec02bc28d29aa72..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu deleted file mode 100644 index 0b6430abc2ff56903f969d6f53ef88e0ff0604c9..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcapall_sm80.cu deleted file mode 100644 index 58891d23eb60fe64a533a05b6cf6e888b368b2f2..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim128_fp16_split_sm80.cu" -#include "flash_fwd_hdim128_fp16_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu deleted file mode 100644 index ea1e266f8d4cd2621be685f4cd2aced78a8dca01..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu deleted file mode 100644 index 2d7488fefe29e627e4c7c9192db54b764a74e7ba..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu deleted file mode 100644 index 8718571e30c2b1177766a0a81b0b8a86d361d6da..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu deleted file mode 100644 index f7dfc18fc1ebace9b4c6e8527b4f4a1f3a70ec7b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu deleted file mode 100644 index 935f5a0fe6084fc3dce8bf93eaa582b6cf59b32f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu deleted file mode 100644 index 3f4d858ff572ef21225c2a41bf41b3891806a37b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu deleted file mode 100644 index 54d720efeb32b028401a6d619213ce6f31aec598..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu deleted file mode 100644 index b9b93af4fc50033570b8345c86b7731fc506fb85..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu deleted file mode 100644 index 39d9167b9f19817a25ed02f685c78af4a54e322b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu deleted file mode 100644 index 0f86458012a8649975a0d611234749ab116fbb7e..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu deleted file mode 100644 index bd6f4df8f696c154e1d8a735d6b7b2c1ebce8087..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu deleted file mode 100644 index 1824b86c64ce3786316879b9b02eb7cc3d5c6e9b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu deleted file mode 100644 index 87dd01725a540f59c29bff4f89b1eca9cf7f1926..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu deleted file mode 100644 index 6594d56012392b19ae445d2f6e72501f81a5b837..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu deleted file mode 100644 index d7dc84ebc1c8f009117f97aa6b941029f8560aca..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu deleted file mode 100644 index b9d6e54cbed3fa19cd04ae6460355b9f5b6bad62..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu deleted file mode 100644 index a8c47652ec15f8c395fe0662604f490a805291d2..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu deleted file mode 100644 index 32d17c7665d612a0524e0cc47cb971c8f12c903f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu deleted file mode 100644 index 365017c256d027bc48a01077db434a17c0bed18c..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu deleted file mode 100644 index 82cfdf040b0dd0279a4245bf1c4ee61bca855f37..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu deleted file mode 100644 index f3254936a47f004f0a261138645e3ff4d2a4dcc0..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu deleted file mode 100644 index 931a6dbf86985659cda2fb7bfb83c8a3ec83a398..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu deleted file mode 100644 index 5c8877a756dd56a03ffaa3ce9dec728a8a65f752..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu deleted file mode 100644 index 1e230ab084bf830cbabc4076526355e63452f7fb..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu deleted file mode 100644 index 03716c862372e3d10310d8c208bc7d5f38a6d430..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu deleted file mode 100644 index 54c66c9552e994ecd8dc7b6e0397138fdbb04cd1..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu deleted file mode 100644 index e5e0ec47db1ae18706d39fb992d38de0085ab8bb..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu deleted file mode 100644 index e4411b5db32833541e2e1aeab5526e9415a78c74..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu deleted file mode 100644 index 157ed06dddf66ff9cd465b5ac700d7c319007339..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu deleted file mode 100644 index 7ef5adc9e85d9ef4ff3dc7fc2e78e693dbf6cdb9..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu deleted file mode 100644 index bf8386b82976f8ca6ada1118e6f4073ae4624413..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu deleted file mode 100644 index cbc6f988424eb5c10494f0133b2a148197f6171a..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu deleted file mode 100644 index d5aa15b5c8ca63eede4368b0bd478a3f9e5760aa..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu deleted file mode 100644 index b8593612df3d3054d7bfb52703552a718d58b384..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu deleted file mode 100644 index a03514d919b075cf4d092a878e6cd22d05294441..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcapall_sm80.cu deleted file mode 100644 index a7f3f9d0fe4d247cf8d16f43dcc44476cb40bc85..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim192_bf16_paged_sm80.cu" -#include "flash_fwd_hdim192_bf16_paged_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu deleted file mode 100644 index df547749e93a135c6c17ba1a1b1615b001dde447..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu deleted file mode 100644 index 1ddb191620981189062f3d745e159ef4c66e9df8..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu deleted file mode 100644 index cefffcd2169977a9c1cf54ba5e517b99c276422e..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu deleted file mode 100644 index 3d4333b9e1f0fb136edaf54378c5dd688292ccba..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcapall_sm80.cu deleted file mode 100644 index 298b0cabafc3ec75f84ae7e4e3484db325267e5a..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim192_bf16_paged_split_sm80.cu" -#include "flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_sm80.cu deleted file mode 100644 index 35a2abef8c9e02931e6a8b4847f614fdf3ac60d4..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_sm90.cu deleted file mode 100644 index 99e34ac0bfb09fb5b4f386d604122c42bff2e38c..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu deleted file mode 100644 index ed1cf22d5c410ac8f47e82314fd70f38e18a152c..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu deleted file mode 100644 index 4527d9a27931e4d244b25e2206ad8c1861d8ea7c..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu deleted file mode 100644 index 41fcf8001701b6581753b9184a478967826cdbe6..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcapall_sm80.cu deleted file mode 100644 index d8b44171204e33fb1cc2478e6a3c0615fc07462f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim192_bf16_sm80.cu" -#include "flash_fwd_hdim192_bf16_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu deleted file mode 100644 index 704cbcb337e95ac6f64b10e69be538a719579f2d..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu deleted file mode 100644 index e0ea082156bb79dad4694b06eb3ece76c4d3ab7f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu deleted file mode 100644 index a9c00408a8b2c919a270fd1db292d1e6575fd3a0..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu deleted file mode 100644 index 1497e7aa8430b300f003335d7b9fa90e8e40fc1a..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcapall_sm80.cu deleted file mode 100644 index 95adb9e2b4f99634b6623cfde868d74d1064777e..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim192_bf16_split_sm80.cu" -#include "flash_fwd_hdim192_bf16_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu deleted file mode 100644 index c66ea9baca1b3935de59247efeafe2f20b48632f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu deleted file mode 100644 index a7e472b478b137bf9819c5cb577cc4f000e3b964..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu deleted file mode 100644 index 9f090aeeda8b46ec40c3aae73e1986288e868eb5..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu deleted file mode 100644 index 2205168a67f4cdbea59908f30076937f5b5f0372..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu deleted file mode 100644 index 2a01898b5604a68c52bbeb300c46bbaaaa6026b4..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_sm90.cu deleted file mode 100644 index 888e241a9f1f1f6ddec5cf8247dab23fe16e089e..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu deleted file mode 100644 index 2a6bde7a39fbd40d2d528cb0ea242e9140d09bdf..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu deleted file mode 100644 index 3d315187b2d4585ede8e6665421e6d3ed743036f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu deleted file mode 100644 index 3c3d0938034d39d70293feff0b5e4e2f22701748..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu deleted file mode 100644 index 4ca103566d673d97f7a1324eecb6795105ced521..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu deleted file mode 100644 index 16debf27799fb8f35873208b60a77e0e39a6b596..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu deleted file mode 100644 index 43c2615718e27ff689ff6c19e774732e77625531..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu deleted file mode 100644 index d9d483838f11c228ed21c3e7baf8883cca8811e4..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu deleted file mode 100644 index 70543998d948f776cf420d0cb6525a087e0460cb..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu deleted file mode 100644 index c30c7e3b8b9022cab1e486032ed16b6220b794ec..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcapall_sm80.cu deleted file mode 100644 index 4bc8c4264b07fb1495c9db298839f3a4e95a1ede..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim192_fp16_paged_sm80.cu" -#include "flash_fwd_hdim192_fp16_paged_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu deleted file mode 100644 index 7ae26e69c96bf2a4fe3f3d8f9cab8c93e490dbb1..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu deleted file mode 100644 index 155b5a539fd9345905e3fdb7162898bc8fdbca80..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu deleted file mode 100644 index 3e6173c31c23a290ee1f13d97f2bfbc22798bdee..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu deleted file mode 100644 index e1e3191a202e0033c3e7079a3742bf12ff22473e..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcapall_sm80.cu deleted file mode 100644 index 6316f81e7b9cec8b08f58981cc5f5d279b958eea..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim192_fp16_paged_split_sm80.cu" -#include "flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_sm80.cu deleted file mode 100644 index 8272ecb76cb40165e09f0a428cf89f5ace6ce456..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_sm90.cu deleted file mode 100644 index 74606c3937381c01a0a01eea84825a0e0cc2510b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu deleted file mode 100644 index 89a58502b375a91a441cc1131d1ecd7e1ff0673c..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu deleted file mode 100644 index b13373806a1a008a9ab20e89bd0f065864f3c955..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu deleted file mode 100644 index 1335fad7f2b66c58822bc8be52134aabdc0099e1..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcapall_sm80.cu deleted file mode 100644 index cc59a01dae9aed4d8a868c79c7aa7b4328595b41..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim192_fp16_sm80.cu" -#include "flash_fwd_hdim192_fp16_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu deleted file mode 100644 index 18c31bdfc0ea9dc93230d82826cd86782aaea68b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu deleted file mode 100644 index 18a5603cf1f51f5c97469ba94cf3adb872c90e7c..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu deleted file mode 100644 index 4e99c7db0276402198aa4eabd1f5228599907a28..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu deleted file mode 100644 index 82f8204aa66609bba5b845979658e840ce18eb27..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcapall_sm80.cu deleted file mode 100644 index 4314bc82236e0e8897dbaba5aeadced086fd665b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim192_fp16_split_sm80.cu" -#include "flash_fwd_hdim192_fp16_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu deleted file mode 100644 index cb851a77110c7d0eb8b43840a083fc795139f707..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu deleted file mode 100644 index ae2871c16558ebe5c67d7321be8a31811ae651fe..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu deleted file mode 100644 index ed24fbffef9e2e53e2e490cb034457c7a1e5bdd1..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu deleted file mode 100644 index ffca9c7f8fe36e11efc1cebe9de6cabb4dd493c9..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu deleted file mode 100644 index 57a06bd6e66002cf7f91736cfab6c7f0505c08dc..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcapall_sm80.cu deleted file mode 100644 index 47db0cac3dd20bc8e0ce98d1c02f101166d169ce..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim256_bf16_paged_sm80.cu" -#include "flash_fwd_hdim256_bf16_paged_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu deleted file mode 100644 index ccdcf21e492929ca83e2c33a3970fe7b11314713..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu deleted file mode 100644 index c2bc7787765e1929bfe3cee0f1cc6aadbac0c449..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu deleted file mode 100644 index 6bba953fc69863996ae757310c10b1c5b39354b7..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu deleted file mode 100644 index 25c96174c79bf203dea2bb301dd4ebbdbd5a6457..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcapall_sm80.cu deleted file mode 100644 index 7f62ded549444ab65960a3b9279fb735a22f30ee..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim256_bf16_paged_split_sm80.cu" -#include "flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_sm80.cu deleted file mode 100644 index f172239e5b95686b51dd98271d58769108f4055e..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_sm90.cu deleted file mode 100644 index 9dde6adb04b6c02c0b69d9baa12d082d99b1b06b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu deleted file mode 100644 index 2317adef8c53ab5615782ee867ef2dfa0bdd5ad2..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu deleted file mode 100644 index b9b3b74867e28f5676e827010c84e636b3f2e346..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu deleted file mode 100644 index c57a5a30abbd8d820b9ae9aef1b4fa14fa56e4b4..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcapall_sm80.cu deleted file mode 100644 index 12851474031205db74385d6fbaaf097b6cbf548b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim256_bf16_sm80.cu" -#include "flash_fwd_hdim256_bf16_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu deleted file mode 100644 index 4f59a6aea92267476a8c84eedd9ce9abd9f51a71..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu deleted file mode 100644 index 2c2de1574ac08ef01b7353be8d0575f35997269a..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu deleted file mode 100644 index 0dbd062c79f8e503ceb8ee99cc2124930c753808..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu deleted file mode 100644 index bee54c702de1b1b2727bd58b55eb600af8f52785..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcapall_sm80.cu deleted file mode 100644 index 2007abeeaba5ec81cf6eaae99444e09cc549f48c..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim256_bf16_split_sm80.cu" -#include "flash_fwd_hdim256_bf16_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu deleted file mode 100644 index c02e6833494d0ba4f9795b75262e06a8cfd36215..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu deleted file mode 100644 index 02b50b98b8d1a31492b772683c5a4b566e65921f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu deleted file mode 100644 index 6599de63bbd654300a2833eebd2f00bfef3d0b1a..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu deleted file mode 100644 index a1cdc775cbb78a85ba25016457ebc68dda3f6241..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu deleted file mode 100644 index 6d01be60f58ab0a6fc1bf23f48235a196d7da29d..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_sm90.cu deleted file mode 100644 index 968bbf36f838107dc345f65f4d6532ebce037699..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu deleted file mode 100644 index d564a622111d492e3bb2f6017d1671dc4dde1615..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu deleted file mode 100644 index cb5bccc176cb570d3be44d3820fce4459300a821..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu deleted file mode 100644 index 146a7bc3430bccee015f6e3dbc120e7e6064fac5..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu deleted file mode 100644 index a195e0931c00565f172c144c4fc2b9cf9f7f9e43..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu deleted file mode 100644 index 045fc71bedb78c444d2cccd4eba49e4b6758d666..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu deleted file mode 100644 index a31da2eddf4ea3ccc95084e5a9779494c5022013..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu deleted file mode 100644 index 7382b58a2313acba7ffce57acd16acd0070b135f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu deleted file mode 100644 index 87ca31ce9027333926b94095d6f4ef3745f8a602..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu deleted file mode 100644 index 60f4d6ebbf6ec4dae3b1d50a47f06d3d4151bef2..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcapall_sm80.cu deleted file mode 100644 index 593bba5280a61c75520dd84f2ae616c2126fb43f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim256_fp16_paged_sm80.cu" -#include "flash_fwd_hdim256_fp16_paged_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu deleted file mode 100644 index e0d5d318bcd26d9370e8d0e4ac556e32270e9c01..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu deleted file mode 100644 index dec7db046bd43d2fd30970e1eff42f47ea5df83f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu deleted file mode 100644 index 7b71f4352260581968c725181779601f71771e8b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu deleted file mode 100644 index 08fc989af8b1feaf84797b4294c61a23231c922d..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcapall_sm80.cu deleted file mode 100644 index f1adde60be67e2646eb683195808f4295785a83e..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim256_fp16_paged_split_sm80.cu" -#include "flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_sm80.cu deleted file mode 100644 index 2cc8b5b86d443434af5035210ce09a60ffcffae3..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_sm90.cu deleted file mode 100644 index 644e268469fa4ed900b56dd203220c9155ca0774..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu deleted file mode 100644 index 1ebcec8b3fef8ebd9f2ea03e729500e611f5773b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu deleted file mode 100644 index 780ade7f6b8451d145ae3ecf0ecf1c75fae842a9..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu deleted file mode 100644 index bfcffe2a39a9f516dc79d06aea66ca7aff43ae08..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcapall_sm80.cu deleted file mode 100644 index 5724ae3309c6b0353abc058d0a26b4da8784b675..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim256_fp16_sm80.cu" -#include "flash_fwd_hdim256_fp16_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu deleted file mode 100644 index ba4ba78ad49ea806df748e538010b83c081987d7..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu deleted file mode 100644 index f04260ba4f1acfb2049b7f623e9b667ca0c97fdf..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu deleted file mode 100644 index 33c78e53059d96fc7a0d556ca80633debc504634..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu deleted file mode 100644 index 8388420921ee70bd1218e272be90fd5180baef22..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcapall_sm80.cu deleted file mode 100644 index fe6b77bb8d23cab3d3e764728966c0459c826bdb..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim256_fp16_split_sm80.cu" -#include "flash_fwd_hdim256_fp16_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu deleted file mode 100644 index 8d037153cbbfaab52cf3a64010e9f068f525ba1b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu deleted file mode 100644 index c62e0b8d822728f7bb0b6e596e20daf8365bc273..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu deleted file mode 100644 index 5e22d67f70027c14ce7038ebe9ed1131b56b1e1f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu deleted file mode 100644 index 1e005b3f018297375570e29729389a2994071064..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu deleted file mode 100644 index 96c4f55afdbdc9baa8b22b952bb1417fcd2ee7cd..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu deleted file mode 100644 index 8a92fe291ee1074590274e52742bf791137c3fff..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu deleted file mode 100644 index f47cb326674059880736e25168dfb1cca0abf97b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu deleted file mode 100644 index 1915feb0463cf7f0d0107d0e9b17fb69e97e2fe4..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu deleted file mode 100644 index fbc15776610254c568e1a36b7a3d9c8c49583de8..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu deleted file mode 100644 index 88445691ffbad68dce341979a4be1a506be9d2bd..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu deleted file mode 100644 index f7d051a34d37a955c0393497b21a2174b0d99ff7..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu deleted file mode 100644 index c83c1741d4f91d383f63d3445e4ccb9b407ae191..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu deleted file mode 100644 index 2e06c89a8c7c9ce1fca71611da21abb7e9b23113..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu deleted file mode 100644 index 46479ec15e14c4c1743d52cc94de58eb67eb154a..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu deleted file mode 100644 index 18681ec42b416129faa090292f1edce49d3801f1..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu deleted file mode 100644 index d2245aa136aa026b8b39cf5448f32dc930f51e7b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu deleted file mode 100644 index 022cdd395768e2d029927bbd954020ab698247a2..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu deleted file mode 100644 index 67a324d52e805a0856cda8fb154efaa876b2180e..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu deleted file mode 100644 index 664f88dbfce0e9ae267764bfa642a1e85353779f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu deleted file mode 100644 index 6bd6b9ab38f6b632b2b16c97643c3b099cb95c1e..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu deleted file mode 100644 index 2f4ceaaed53398201d5666028a9e0015e14426bf..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu deleted file mode 100644 index 5fd59af3486edc4115e8c6e0f745dccbe60d6d10..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu deleted file mode 100644 index e0f885b0f72ae4858470b72930b9ccea8ce8eb40..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu deleted file mode 100644 index 6dcda019627eeb34f7622ea4fda6ce9461b2bb95..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu deleted file mode 100644 index 5d20be6d2a7428fc95b7be228cd50b46b9f2eb9b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu deleted file mode 100644 index 47463a7151cbfdcbe3a59f4c009fc249886a38aa..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu deleted file mode 100644 index 622b5533ce85fbd8ba7818b23b0fdbecdd70c2f8..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu deleted file mode 100644 index c83f44722cda20dea1321a1ca9de61bc8d278e79..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu deleted file mode 100644 index 5c9130f86481367208c67397eac23c9cb7ace3b5..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu deleted file mode 100644 index a152022cb650340d0b409f419a5ca587054344d5..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu deleted file mode 100644 index ef05aa2038dd46a2bc6b5f43195eca342ac2bd6f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu deleted file mode 100644 index 19fe6d94f7dc6776c8124deaec9db7dd50dc57bb..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu deleted file mode 100644 index 6eb2d3d134bc9320a256f1b06164dfca347a7119..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu deleted file mode 100644 index ffbc9982122290c2f296c4c34c2bdc15f3471f90..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu deleted file mode 100644 index 3d35075b48d31bfcdfd43e7407573d981da2f949..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu deleted file mode 100644 index c2af33cf533228efe2bbfe89b82b1ad4578e131a..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu deleted file mode 100644 index e07547c92d0391f3436103188c8dc56a1acfcfe5..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu deleted file mode 100644 index 1a04eb01f5e9143020bbd562bd477588e6740881..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu deleted file mode 100644 index da9afc1157167534a6b60718cfac914b977a6817..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu deleted file mode 100644 index 5e63a15515fad74539074fe8781f4a63f0bc1742..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu deleted file mode 100644 index 4134d7d80bbd84711c0920e68ff22bd7f86269b0..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu deleted file mode 100644 index 11e3503b0d9263f11bf1f656e86af0ef4e95f90e..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu deleted file mode 100644 index 67e39bd7371b72956ff47843220a07566205dd5a..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu deleted file mode 100644 index c37844daa562b8d143223db22981c2733ffa9cb6..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu deleted file mode 100644 index f0c40e2f89fe9a41826447a4806788d149bd15ed..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcapall_sm80.cu deleted file mode 100644 index 850a47e25c55352d371818a2a9cfaf7025ef1db8..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_bf16_paged_sm80.cu" -#include "flash_fwd_hdim64_bf16_paged_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu deleted file mode 100644 index 3ed9694908c466b3bd153418a672347359deb640..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu deleted file mode 100644 index 4a16aae66c00c5fc8659dffa2657428fdd37086a..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu deleted file mode 100644 index b5b5fc26b280acefe728280db2f809e285f08fc2..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu deleted file mode 100644 index 3b29be627edc6b366eb635181619b77b9469d5dc..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcapall_sm80.cu deleted file mode 100644 index 3cc805d00aa83e4ff8e43ef44f9d091ed359d548..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_bf16_paged_split_sm80.cu" -#include "flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_sm80.cu deleted file mode 100644 index 5f1c298c4c409464d70fb1a30eedf54102b9b460..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_sm90.cu deleted file mode 100644 index 64895643d2009214caa4dbc2d61785d98103bd68..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu deleted file mode 100644 index dd508590d66f6ebde86b709a4dbb961f1cdfb9f4..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu deleted file mode 100644 index 8411b6fccbdf0d1d1a7d140cdd3924e220df5be0..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu deleted file mode 100644 index b5b4f40770e5576d45d68e43fe12de393ed10e76..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcapall_sm80.cu deleted file mode 100644 index b7972b68eb925c3615b4a951779c7b3605f8cfff..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_bf16_sm80.cu" -#include "flash_fwd_hdim64_bf16_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu deleted file mode 100644 index e608da04b022d3286af44a87b879c19f39f13870..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu deleted file mode 100644 index c69b78ac3b6015cccc0fc7a39df15bae0230afde..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu deleted file mode 100644 index 170cdb5cb8c6e6795aa75db821c8d22a01079c67..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu deleted file mode 100644 index ef0d1e921c1cecf2fbbebd04ea2b7a1d46fe5eb0..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcapall_sm80.cu deleted file mode 100644 index f8adb78ec202909f35e68515bbfb58c16a2973c6..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_bf16_split_sm80.cu" -#include "flash_fwd_hdim64_bf16_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu deleted file mode 100644 index 6a7fc29dddaae4145b4d6e8f760032871e0e1c60..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu deleted file mode 100644 index faeb6c487fc6fb459b6261a2a2fdb9f124f9149a..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu deleted file mode 100644 index 655258d5194232d03dbd292c41693fd0c63a8940..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu deleted file mode 100644 index 4bd8ad8f267f5824d862d1343fd116359131850d..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu deleted file mode 100644 index 657820f28547f075a26085c447b6f79aa7a43ace..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_sm90.cu deleted file mode 100644 index cb0955d1a53e1dd4f039dfb35efe4cb0302e8f2e..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu deleted file mode 100644 index 357b64e83b67bf763b32e969505ef9686dc85291..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu deleted file mode 100644 index c1207925864503e30feeae9af161046c4ad57c5f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu deleted file mode 100644 index 21687f8932b72c2d6323e2346ce8bc0d74f62eb3..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu deleted file mode 100644 index 4df8ed64d7be8192ee6ac05ac83dd8726796c419..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu deleted file mode 100644 index b601195d7e095577be2cc2f01425cbfc5d6d2ced..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu deleted file mode 100644 index ced475318983e384d1802f0826ff9d38165c52d9..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu deleted file mode 100644 index 03090f73cb250b897f3392d41a1b41a28aa3ee39..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu deleted file mode 100644 index d6fe1559ca11407e0a2c75bc6a50d1f501f8c1d2..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu deleted file mode 100644 index 7b5ae4a56aa2ea5265bede2dc88e0e4cba3ebe35..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcapall_sm80.cu deleted file mode 100644 index 4dd44519961b256eb33bb4727b0790d9bfc88946..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_fp16_paged_sm80.cu" -#include "flash_fwd_hdim64_fp16_paged_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu deleted file mode 100644 index 6c603b4dcaaf6e4d6f70e25470a52cc69e1f8406..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu deleted file mode 100644 index 26d25fc1909b3be0251f7b7f4d4a8fa1e867aa9c..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu deleted file mode 100644 index 05a0baf18b6a0dd2fbe6bfc08e104d1bb5811f7c..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu deleted file mode 100644 index 3a45776537f1ba7ad56b7bf873d413bbdf0caa10..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcapall_sm80.cu deleted file mode 100644 index c7348b075735a222020f73383eba0bfadfd9a36f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_fp16_paged_split_sm80.cu" -#include "flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_sm80.cu deleted file mode 100644 index 9b80bae51f6e579b1993f976343b453145cf30de..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_sm90.cu deleted file mode 100644 index f6810efafb865253d66c2797e7fcc39bf446f87a..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu deleted file mode 100644 index 98c018893f1e63c8d6e0e81edf0a622a15d7947b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu deleted file mode 100644 index a10dfaca722bb629cdecb348a3800bd519e3f549..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu deleted file mode 100644 index b912a81443e2688aa2ada96b5181c00d0d9fcbb3..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcapall_sm80.cu deleted file mode 100644 index 6e2756085ad05e72e5fce7809847434becdbccdd..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_fp16_sm80.cu" -#include "flash_fwd_hdim64_fp16_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu deleted file mode 100644 index 8603c396e1f4f8edb28f5aaf632f10dec948d75f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu deleted file mode 100644 index dc55dbc66aaabfddc6422614eb0055940923163b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu deleted file mode 100644 index ef48844972a43736f44e7550adb9739c4287d178..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu deleted file mode 100644 index b1c0ead6e5c48dd41d6356f4cf56da69d64322a2..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcapall_sm80.cu deleted file mode 100644 index fa4c8095449af058afbd457723e6f3ea2c28ca1d..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_fp16_split_sm80.cu" -#include "flash_fwd_hdim64_fp16_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu deleted file mode 100644 index 5d76d0fff0466c5eaf3c23a0b9abd1c7ee31c5be..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu deleted file mode 100644 index 44ea823d272b7f76d6be2b29aed8d7d15d2d6fe2..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu deleted file mode 100644 index 30fe623508b14c4584c56d72456be3d26559f9d2..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu deleted file mode 100644 index 6eb12dc80a60248305c0e2dc5e297380f0073361..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu deleted file mode 100644 index b806fc9d501dcb79c05728a02b8564bab147f8d2..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcapall_sm80.cu deleted file mode 100644 index 989d42b270840723d84286bb3474263557800d52..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim96_bf16_paged_sm80.cu" -#include "flash_fwd_hdim96_bf16_paged_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu deleted file mode 100644 index 8f0a26da03a116ea3b25982f40eeff3823cbea2b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu deleted file mode 100644 index 6de2819a17293dab4d23d8cd875350b4d1627e8c..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu deleted file mode 100644 index 16927295b8203e797f17bdbd5998677072b64e0f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu deleted file mode 100644 index 0841307209230557940e4a27d36afaed02865df3..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcapall_sm80.cu deleted file mode 100644 index 72654c12104cd08cd23c83321b9f2249a4af843b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim96_bf16_paged_split_sm80.cu" -#include "flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_sm80.cu deleted file mode 100644 index 7d4dcdc293bdfe23c1167f2af9da53d1b4745006..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_sm90.cu deleted file mode 100644 index b4dfbf7f8b2c58bd7a3c2a8573fc2a9b11badb5d..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu deleted file mode 100644 index 1fa048752dcea50711ea473c27ab0b33e9fe30b2..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu deleted file mode 100644 index e0b6a75e63581cef18e85418c7f26579037dd1d0..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu deleted file mode 100644 index e257b42f79ccec1a9647ea73b6b882aa86a2f668..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcapall_sm80.cu deleted file mode 100644 index 31268ea3dc58503cc7f8e42c2cb276494bdcafe5..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim96_bf16_sm80.cu" -#include "flash_fwd_hdim96_bf16_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu deleted file mode 100644 index f97ab4733a08e66c7604e14a6b75de4628ecd952..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu deleted file mode 100644 index cee43ef94cdd9e3bb8803284f1d2137fb9d8b835..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu deleted file mode 100644 index 0442e1f94b533edc8d7cc548713b3ec84b6691cb..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu deleted file mode 100644 index bc71fa9e71fc89dd5c029faec448557def185ed4..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcapall_sm80.cu deleted file mode 100644 index 435607c03cf2ccd32bfe5928b47fccda4eb5b333..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim96_bf16_split_sm80.cu" -#include "flash_fwd_hdim96_bf16_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu deleted file mode 100644 index b61dd71885d9d13ed6e6e4e9883606f4483a8afd..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu deleted file mode 100644 index f47e1f5cdacf18f2758b0ff74a7208113ba5f17c..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu deleted file mode 100644 index 215752f1b0621a3823bf8c854d287d35710afc27..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu deleted file mode 100644 index 207afc79242fec88b1f2dca95c1214b736613ed4..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu deleted file mode 100644 index 6c38c083384b5968c80d53aacf9d15b4bd8bd01c..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_sm90.cu deleted file mode 100644 index dc2eb35dc2984ac0d4c8599a2177ded3c92a1c28..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu deleted file mode 100644 index f04e8bca6f375f75e0973f673c66d099e4ab2ea7..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu deleted file mode 100644 index 2697f6910e7ec9e137c78c38c141facc62ba583e..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu deleted file mode 100644 index e7a98b2e6ee45014645d64d4300b8141f12d0acc..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu deleted file mode 100644 index 98fb39c86ee5f593eb746dffaf8e3ff2227a27e7..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu deleted file mode 100644 index cb938ad93b05592b94ba5d9fa181a70a54fe7125..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu deleted file mode 100644 index e2dc45c79c6cd21a2f5b22e75fbbbbfe9207dbb7..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu deleted file mode 100644 index 64f99c05a3224317360ece6215e24a6297709602..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu deleted file mode 100644 index 3fdbbf23bacd71d9f876cb2dcca8e6db5c17c16f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu deleted file mode 100644 index ffe202ee394bce5d86c156bb854fed0c51dffb05..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcapall_sm80.cu deleted file mode 100644 index 3ce908b2f97a02445e9bcc026b633f24c6ad8b17..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim96_fp16_paged_sm80.cu" -#include "flash_fwd_hdim96_fp16_paged_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu deleted file mode 100644 index 42740f0228b3e6295901bc3f99495817e1dd613b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu deleted file mode 100644 index 829929980d0d605c5a8c1555d2d8a436f4ca56f5..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu deleted file mode 100644 index d6a330432a4579785ecbaa3b338664ee9bd1ad9d..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu deleted file mode 100644 index 39c774e6f7797bd9158a813af1c09e8829076bad..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcapall_sm80.cu deleted file mode 100644 index 2d504315825da06c5b06cdf2c79d0f8dfb32d758..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim96_fp16_paged_split_sm80.cu" -#include "flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_sm80.cu deleted file mode 100644 index bc54be11e6c3d446dae7401a0f159cb26ddfb280..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_sm90.cu deleted file mode 100644 index a68790500d8701f24d1a3e937dbd9733f715f867..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu deleted file mode 100644 index 3bca3065c7f8117ea903e2b755ae379c4a6952f7..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu deleted file mode 100644 index 985692b9fa1aa6040910b54a4a6e38cf526e0e21..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu deleted file mode 100644 index 3c99cb6b5a007154f2f18071b5c67e0a63b2c149..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcapall_sm80.cu deleted file mode 100644 index 1c9565de0d6d6f22c21040837753e16386033b6c..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim96_fp16_sm80.cu" -#include "flash_fwd_hdim96_fp16_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu deleted file mode 100644 index cf77a1ae819172eb1662f48464a1dcc6d298681d..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu deleted file mode 100644 index f9a46a44dd5c0c96a9b50ca3268836aa77fa401b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu deleted file mode 100644 index 9b4dbbba58aefd8cf4498b1333d98e75c10ac81f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_SM8x -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu deleted file mode 100644 index da5373fd13e6fc64ef45fa159a286dcff8af9e33..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_launch_template.h" - -#ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcapall_sm80.cu deleted file mode 100644 index 921fcd4292abf6d6dc0d9b6e8b16677f8f3c8325..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcapall_sm80.cu +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim96_fp16_split_sm80.cu" -#include "flash_fwd_hdim96_fp16_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu deleted file mode 100644 index 8b659e8321b7ca811f5422d4fa09c97d605f2020..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_bf16_packgqa_sm90.cu" -#include "flash_fwd_hdim96_bf16_packgqa_sm90.cu" -#include "flash_fwd_hdim128_bf16_packgqa_sm90.cu" -#include "flash_fwd_hdim192_bf16_packgqa_sm90.cu" -#include "flash_fwd_hdim256_bf16_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu deleted file mode 100644 index c84d02b6d04c7daafaeb3e5eccc1e401bbe18712..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_bf16_paged_sm90.cu" -#include "flash_fwd_hdim96_bf16_paged_sm90.cu" -#include "flash_fwd_hdim128_bf16_paged_sm90.cu" -#include "flash_fwd_hdim192_bf16_paged_sm90.cu" -#include "flash_fwd_hdim256_bf16_paged_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu deleted file mode 100644 index 6aaf7d12f5677c5beace8562b950955a4da394e6..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_bf16_paged_softcap_sm90.cu" -#include "flash_fwd_hdim96_bf16_paged_softcap_sm90.cu" -#include "flash_fwd_hdim128_bf16_paged_softcap_sm90.cu" -#include "flash_fwd_hdim192_bf16_paged_softcap_sm90.cu" -#include "flash_fwd_hdim256_bf16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu deleted file mode 100644 index 117121414197181eb850e2b73abd3c26ffd5c4f7..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_bf16_paged_split_sm90.cu" -#include "flash_fwd_hdim96_bf16_paged_split_sm90.cu" -#include "flash_fwd_hdim128_bf16_paged_split_sm90.cu" -#include "flash_fwd_hdim192_bf16_paged_split_sm90.cu" -#include "flash_fwd_hdim256_bf16_paged_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu deleted file mode 100644 index 6175723086cee981ed4cb7ff133db55db0c13f3a..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_bf16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_bf16_sm90.cu deleted file mode 100644 index 2aac1970b1b0322f086e01df74f7d5875c202479..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_bf16_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_bf16_sm90.cu" -#include "flash_fwd_hdim96_bf16_sm90.cu" -#include "flash_fwd_hdim128_bf16_sm90.cu" -#include "flash_fwd_hdim192_bf16_sm90.cu" -#include "flash_fwd_hdim256_bf16_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu deleted file mode 100644 index be0c5af080bf713b053d426d58c3753352d04568..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu deleted file mode 100644 index fd5893c59f432e312162d9a8da24592c35ed09d1..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_bf16_softcap_sm90.cu" -#include "flash_fwd_hdim96_bf16_softcap_sm90.cu" -#include "flash_fwd_hdim128_bf16_softcap_sm90.cu" -#include "flash_fwd_hdim192_bf16_softcap_sm90.cu" -#include "flash_fwd_hdim256_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu deleted file mode 100644 index bcde9c9458291f1ec9c1c7cf4705c9ad601a95a1..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_bf16_split_sm90.cu" -#include "flash_fwd_hdim96_bf16_split_sm90.cu" -#include "flash_fwd_hdim128_bf16_split_sm90.cu" -#include "flash_fwd_hdim192_bf16_split_sm90.cu" -#include "flash_fwd_hdim256_bf16_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu deleted file mode 100644 index 160eb3a18e4817c1aed0f422cf5bba535ba5f981..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_bf16_split_softcap_sm90.cu" -#include "flash_fwd_hdim96_bf16_split_softcap_sm90.cu" -#include "flash_fwd_hdim128_bf16_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_bf16_split_softcap_sm90.cu" -#include "flash_fwd_hdim256_bf16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu deleted file mode 100644 index 28819a690a3b4341557d994ecd4021d9be0bd81e..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_e4m3_packgqa_sm90.cu" -#include "flash_fwd_hdim96_e4m3_packgqa_sm90.cu" -#include "flash_fwd_hdim128_e4m3_packgqa_sm90.cu" -#include "flash_fwd_hdim192_e4m3_packgqa_sm90.cu" -#include "flash_fwd_hdim256_e4m3_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu deleted file mode 100644 index 933ad98271957d41f3d43f487d72a37467185c76..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_e4m3_paged_sm90.cu" -#include "flash_fwd_hdim96_e4m3_paged_sm90.cu" -#include "flash_fwd_hdim128_e4m3_paged_sm90.cu" -#include "flash_fwd_hdim192_e4m3_paged_sm90.cu" -#include "flash_fwd_hdim256_e4m3_paged_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu deleted file mode 100644 index a934f7d9924eb7a1bfc7b5ed93a1e6c1c206e94d..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu" -#include "flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu" -#include "flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu" -#include "flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu" -#include "flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu deleted file mode 100644 index 8475e878ae2d017f97d46bb27955d1fa91ae5e58..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_e4m3_paged_split_sm90.cu" -#include "flash_fwd_hdim96_e4m3_paged_split_sm90.cu" -#include "flash_fwd_hdim128_e4m3_paged_split_sm90.cu" -#include "flash_fwd_hdim192_e4m3_paged_split_sm90.cu" -#include "flash_fwd_hdim256_e4m3_paged_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu deleted file mode 100644 index dd1405b17f06c5d59a4a49c662ead4e4f8b9504c..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_sm90.cu deleted file mode 100644 index 7e7d806c6d575c095fdb76d8fb82d27014cba1d4..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_e4m3_sm90.cu" -#include "flash_fwd_hdim96_e4m3_sm90.cu" -#include "flash_fwd_hdim128_e4m3_sm90.cu" -#include "flash_fwd_hdim192_e4m3_sm90.cu" -#include "flash_fwd_hdim256_e4m3_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu deleted file mode 100644 index f973a4e411d039e97f8ce6c855f25045fbcc1803..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu deleted file mode 100644 index 30390838d3901f1d95db65d5adfb263f4c9a6c30..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_e4m3_softcap_sm90.cu" -#include "flash_fwd_hdim96_e4m3_softcap_sm90.cu" -#include "flash_fwd_hdim128_e4m3_softcap_sm90.cu" -#include "flash_fwd_hdim192_e4m3_softcap_sm90.cu" -#include "flash_fwd_hdim256_e4m3_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu deleted file mode 100644 index 0b629bd2b3224a52238a7814768e94a6d3d398f1..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_e4m3_split_sm90.cu" -#include "flash_fwd_hdim96_e4m3_split_sm90.cu" -#include "flash_fwd_hdim128_e4m3_split_sm90.cu" -#include "flash_fwd_hdim192_e4m3_split_sm90.cu" -#include "flash_fwd_hdim256_e4m3_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu deleted file mode 100644 index 818c7fafb7a0eac09887061adbcfe3737e777f9b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_e4m3_split_softcap_sm90.cu" -#include "flash_fwd_hdim96_e4m3_split_softcap_sm90.cu" -#include "flash_fwd_hdim128_e4m3_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_e4m3_split_softcap_sm90.cu" -#include "flash_fwd_hdim256_e4m3_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu deleted file mode 100644 index 6652824d0752d4247e407256e7c5ea4fded56c4a..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_fp16_packgqa_sm90.cu" -#include "flash_fwd_hdim96_fp16_packgqa_sm90.cu" -#include "flash_fwd_hdim128_fp16_packgqa_sm90.cu" -#include "flash_fwd_hdim192_fp16_packgqa_sm90.cu" -#include "flash_fwd_hdim256_fp16_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu deleted file mode 100644 index 05d11e2e2583660775720695fe88e21650d38d4d..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_fp16_paged_sm90.cu" -#include "flash_fwd_hdim96_fp16_paged_sm90.cu" -#include "flash_fwd_hdim128_fp16_paged_sm90.cu" -#include "flash_fwd_hdim192_fp16_paged_sm90.cu" -#include "flash_fwd_hdim256_fp16_paged_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu deleted file mode 100644 index b638138eb2623a7dc475431063a1cdc120ab6f44..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_fp16_paged_softcap_sm90.cu" -#include "flash_fwd_hdim96_fp16_paged_softcap_sm90.cu" -#include "flash_fwd_hdim128_fp16_paged_softcap_sm90.cu" -#include "flash_fwd_hdim192_fp16_paged_softcap_sm90.cu" -#include "flash_fwd_hdim256_fp16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu deleted file mode 100644 index 3619a2175f086159c98c22e5cca4e885c7320ae4..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_fp16_paged_split_sm90.cu" -#include "flash_fwd_hdim96_fp16_paged_split_sm90.cu" -#include "flash_fwd_hdim128_fp16_paged_split_sm90.cu" -#include "flash_fwd_hdim192_fp16_paged_split_sm90.cu" -#include "flash_fwd_hdim256_fp16_paged_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu deleted file mode 100644 index 3a408ceacbd6ffa11d6a4502d83f6f397407c1d1..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_fp16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_fp16_sm90.cu deleted file mode 100644 index eec11be9162c565066a8e8851fff349cef89d910..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_fp16_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_fp16_sm90.cu" -#include "flash_fwd_hdim96_fp16_sm90.cu" -#include "flash_fwd_hdim128_fp16_sm90.cu" -#include "flash_fwd_hdim192_fp16_sm90.cu" -#include "flash_fwd_hdim256_fp16_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu deleted file mode 100644 index ca2a1e1b843c2b79d3bc15754720653f5aa94897..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu deleted file mode 100644 index 8cf31a8a85f76445d754416827c665af8da17df5..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_fp16_softcap_sm90.cu" -#include "flash_fwd_hdim96_fp16_softcap_sm90.cu" -#include "flash_fwd_hdim128_fp16_softcap_sm90.cu" -#include "flash_fwd_hdim192_fp16_softcap_sm90.cu" -#include "flash_fwd_hdim256_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu deleted file mode 100644 index 5ee7ace63ace05d2b5a1b875652a2313f93c9fb0..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_fp16_split_sm90.cu" -#include "flash_fwd_hdim96_fp16_split_sm90.cu" -#include "flash_fwd_hdim128_fp16_split_sm90.cu" -#include "flash_fwd_hdim192_fp16_split_sm90.cu" -#include "flash_fwd_hdim256_fp16_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu deleted file mode 100644 index 4da0ee704eb8d736dcb44a5dfa4b13f62e82b8e0..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_fp16_split_softcap_sm90.cu" -#include "flash_fwd_hdim96_fp16_split_softcap_sm90.cu" -#include "flash_fwd_hdim128_fp16_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_fp16_split_softcap_sm90.cu" -#include "flash_fwd_hdim256_fp16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu deleted file mode 100644 index ddd8bf07c4a601c34d04b7c0175093d6b755defd..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_256_bf16_packgqa_sm90.cu" -#include "flash_fwd_hdim64_512_bf16_packgqa_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu deleted file mode 100644 index c9494c4f1d23fecdda5f219b432d0ebae9db9112..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_256_bf16_paged_sm90.cu" -#include "flash_fwd_hdim64_512_bf16_paged_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_paged_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu deleted file mode 100644 index 4b2ec583cfdb4d89899945e2f88e33b5c010b096..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu" -#include "flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu deleted file mode 100644 index 306722d45865b08d7f59a1e9655822764fb2eb01..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_256_bf16_paged_split_sm90.cu" -#include "flash_fwd_hdim64_512_bf16_paged_split_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_paged_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu deleted file mode 100644 index e44b2d2465463f8aac1ed28631f27b94eb38bee6..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu deleted file mode 100644 index d52417daef3c6d8de6af777322e308ec104babf9..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_256_bf16_sm90.cu" -#include "flash_fwd_hdim64_512_bf16_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu deleted file mode 100644 index 6428c461aa9c485bcb9856b3f8dfe2a41a5a3522..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu deleted file mode 100644 index d0df6306e28f059c35350be22b9e71c212b4b94b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_256_bf16_softcap_sm90.cu" -#include "flash_fwd_hdim64_512_bf16_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu deleted file mode 100644 index e116d3ea7c7f98b778f27b7465b8976c664c4282..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_256_bf16_split_sm90.cu" -#include "flash_fwd_hdim64_512_bf16_split_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu deleted file mode 100644 index bededf4a7d8f93d4503dcfac297c079d3d75880f..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu" -#include "flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu deleted file mode 100644 index 526a51fb71e0d98ff104187027313cd24fdfec4c..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu deleted file mode 100644 index 4e5d9cc4fe2f2cc873f789d4bdf47b12a7aa808d..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim192_128_e4m3_paged_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu deleted file mode 100644 index f553af139f22e59c1dfc6abdee828d9c8e2afe00..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu deleted file mode 100644 index aa2a8260d2500c0db89600641b2948f81e68219c..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu deleted file mode 100644 index bbc4449ba21cedad486767b2d985769e56193924..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu deleted file mode 100644 index 02ca85ad672277db5e85b827effb8a4737d08201..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim192_128_e4m3_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu deleted file mode 100644 index d090fde972b919a50afd37b109cdc66e11226093..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu deleted file mode 100644 index d48f60ad7e2fc8b133ed5bdb629d3aba93d1ae37..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim192_128_e4m3_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu deleted file mode 100644 index 9dda19d1ceaf8fa43aaca9e87a1051509c15ee05..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim192_128_e4m3_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu deleted file mode 100644 index f3e51fc9ebdf1f37eb329a5aa39b0d7e326ac0e9..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu deleted file mode 100644 index ea531027938b8f74f86a2f1a0c9fc8195b99e0d3..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_256_fp16_packgqa_sm90.cu" -#include "flash_fwd_hdim64_512_fp16_packgqa_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu deleted file mode 100644 index 10d86e5e99c4cd379f94fba6127355a1b5204c04..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_256_fp16_paged_sm90.cu" -#include "flash_fwd_hdim64_512_fp16_paged_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_paged_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu deleted file mode 100644 index 375197ef75e54d468d8fed3524abaf50cd9fb075..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu" -#include "flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu deleted file mode 100644 index 4fc4831cf58d6f188cb21a9b3ca9824d276d49a5..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_256_fp16_paged_split_sm90.cu" -#include "flash_fwd_hdim64_512_fp16_paged_split_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_paged_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu deleted file mode 100644 index a3d94a163a9a2158c4b9ddf1c468c428f4847198..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu deleted file mode 100644 index 9663103ae117876652399c211bbb778a4d048069..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_256_fp16_sm90.cu" -#include "flash_fwd_hdim64_512_fp16_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu deleted file mode 100644 index b7d2b07ca840dfd9768811edec1848838aafac8b..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu deleted file mode 100644 index 471b5abaafc555ccdf8eeceb706e649358786db1..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_256_fp16_softcap_sm90.cu" -#include "flash_fwd_hdim64_512_fp16_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu deleted file mode 100644 index 10f72182fa97128ff83db68d029ae4a2cbdce2d5..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_256_fp16_split_sm90.cu" -#include "flash_fwd_hdim64_512_fp16_split_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu deleted file mode 100644 index 54db60c23b162904546c89e2f249deab2c59e942..0000000000000000000000000000000000000000 --- a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -// Splitting the different template instantiations to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include "flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu" -#include "flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/mainloop_bwd_sm80.hpp b/flash-attn/mainloop_bwd_sm80.hpp deleted file mode 100644 index 0a79670f4757e8ff3a9a9cc7ce34026d79e68d1b..0000000000000000000000000000000000000000 --- a/flash-attn/mainloop_bwd_sm80.hpp +++ /dev/null @@ -1,901 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include -#include -#include -#include - -#include "cute/tensor.hpp" - -#include "seqlen.h" -#include "mask.h" -#include "mask.h" -#include "softmax.h" -#include "utils.h" - -namespace flash { - -using namespace cute; - -template -struct CollectiveMainloopBwdSm80 { - - static constexpr int kStages = Stages; - static constexpr int kStages_dO = Stages_dO; - static_assert(kStages >= kStages_dO); - using TileShape_MNK = TileShape_MNK_; - using Element = Element_; - using ElementAccum = ElementAccum_; - using ArchTag = ArchTag_; - static constexpr bool Is_causal = Is_causal_; - static constexpr bool Is_local = Is_local_; - static constexpr bool Has_softcap = Has_softcap_; - static constexpr bool Varlen = Varlen_; - static constexpr int NumMmaWarps = NumMmaWarpGroups * cutlass::NumWarpsPerWarpGroup; - - static constexpr bool SdP_swapAB = SdP_swapAB_; - static constexpr bool dKV_swapAB = dKV_swapAB_; - static constexpr bool dQ_swapAB = dQ_swapAB_; - - static constexpr bool Q_dO_same_stages = kStages == kStages_dO; - - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - static constexpr int kHeadDim = get<2>(TileShape_MNK{}); - - using SeqlenInfo_t = flash::SeqlenInfoQK; - using BlockMN_t = flash::BlockMN; - - static_assert(ArchTag::kMinComputeCapability >= 80); - - static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80; - - static constexpr int NumMmaThreads = NumMmaWarps * cutlass::NumThreadsPerWarp; - static constexpr int NumProducerThreads = NumMmaThreads; // For compatibility with TileScheduler - - using MMA_Atom_Arch = std::conditional_t< - ArchTag::kMinComputeCapability >= 80, - std::conditional_t< - std::is_same_v, - MMA_Atom, - MMA_Atom - >, - MMA_Atom - >; - - static_assert(NumMmaWarps % AtomLayoutMSdP == 0); - static_assert(NumMmaWarps % AtomLayoutNdKV == 0); - static_assert(NumMmaWarps % AtomLayoutMdQ == 0); - static constexpr bool Mma_dKV_is_RS = AtomLayoutMSdP == 1 && AtomLayoutNdKV == NumMmaWarps && SdP_swapAB && !dKV_swapAB; - static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == NumMmaWarps && AtomLayoutMdQ == NumMmaWarps && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS - - using AtomLayoutSdP = std::conditional_t< - !SdP_swapAB, - Layout, Int, _1>>, - Layout, Int, _1>> - >; - static constexpr bool MmaSdPEvenN = ((!SdP_swapAB ? kBlockN : kBlockM) / size<1>(AtomLayoutSdP{})) % 16 == 0; - using TiledMmaSdP = TiledMMA< - MMA_Atom_Arch, - AtomLayoutSdP, - Tile(AtomLayoutSdP{}))>, Int<(MmaSdPEvenN ? 16 : 8) * CUTE_STATIC_V(size<1>(AtomLayoutSdP{}))>, _16>>; - - using AtomLayoutdKV = std::conditional_t< - !dKV_swapAB, - Layout, Int, _1>>, - Layout, Int, _1>> - >; - static constexpr bool MmadKVEvenN = ((!dKV_swapAB ? kHeadDim : kBlockN) / size<1>(AtomLayoutdKV{})) % 16 == 0; - using TiledMmadKV = TiledMMA< - MMA_Atom_Arch, - AtomLayoutdKV, - Tile(AtomLayoutdKV{}))>, Int<(MmadKVEvenN ? 16 : 8) * CUTE_STATIC_V(size<1>(AtomLayoutdKV{}))>, _16>>; - - using AtomLayoutdQ = std::conditional_t< - !dQ_swapAB, - Layout, Int, _1>>, - Layout, Int, _1>> - >; - static constexpr bool MmadQEvenN = ((!dQ_swapAB ? kHeadDim : kBlockM) / size<1>(AtomLayoutdQ{})) % 16 == 0; - using TiledMmadQ = TiledMMA< - MMA_Atom_Arch, - AtomLayoutdQ, - Tile(AtomLayoutdQ{}))>, Int<(MmadQEvenN ? 16 : 8) * CUTE_STATIC_V(size<1>(AtomLayoutdQ{}))>, _16>>; - - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); - // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each - // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. - static constexpr int kBytePerRow = kHeadDim * sizeof(Element); - static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); - - static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1)); - static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4); - - // We need to accommodate both Q and Q^T (and dO and dO^T) in shared memory. - // Q & dO are used in the SdP Mma and Q^T and dO^T are used in the dKV Mma. - // Since this is GMMA::Major::K, the M dimension (kBlockM) doesn't matter for the layout, only the K dimension - // changes the layout. - using SmemLayoutAtomQdO = decltype( - composition(Swizzle{}, - Layout>, - Stride, _1>>{})); - using SmemLayoutQ = - decltype(tile_to_shape(SmemLayoutAtomQdO{}, - make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); - using SmemLayoutdO = - decltype(tile_to_shape(SmemLayoutAtomQdO{}, - make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); - - using SmemLayoutAtomKV = decltype( - composition(Swizzle{}, - // TODO: FA2 has a slightly different layout, does it matter? - Layout>, - Stride, _1>>{})); - using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomKV{}, select<1, 2>(TileShape_MNK{}))); - - using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomKV{}, select<1, 2>(TileShape_MNK{}))); - - // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. - static constexpr int kPBlockN = kBlockN % 64 == 0 ? 64 : (kBlockN % 32 == 0 ? 32 : 16); - static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); - // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); - static constexpr int kSwizzlePdS = 3; - using SmemLayoutAtomPdS = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride, _1>>{})); - using SmemLayoutPdS = decltype(tile_to_shape( - SmemLayoutAtomPdS{}, - make_shape(Int{}, Int{}))); - - // We set stride to be multiple of 64 so that if ShuffleLSE, even if threads read from sLSE but out of bounds, - // it's still a valid smem address. - using SmemLayoutLSE = cute::Layout, Int>, cute::Stride<_1, Int>>; - using SmemLayoutLSEMma = std::conditional_t< - SdP_swapAB, - cute::Layout, Int, Int>, cute::Stride<_0, _1, Int>>, - cute::Layout, Int, Int>, cute::Stride<_1, _0, Int>> - >; - - // Note this is the transpose in terms of the view, not in terms of memory. - using SmemLayoutQt = - decltype(cute::composition(SmemLayoutQ{}, - make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), - make_stride(Int{}, _1{}, Int{})))); - using SmemLayoutdOt = - decltype(cute::composition(SmemLayoutdO{}, - make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), - make_stride(Int{}, _1{}, Int{})))); - using SmemLayoutKt = - decltype(cute::composition(SmemLayoutK{}, - make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), - make_stride(Int{}, _1{})))); - using SmemLayoutPdSt = - decltype(cute::composition(SmemLayoutPdS{}, - make_layout(make_shape(Int{}, Int{}), - make_stride(Int{}, _1{})))); - - // Thread layout, 256 or 384 threads per row - using R2SLayoutAtomdQaccum = Layout>>; - using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2SLayoutAtomdQaccum{}, - Layout>{})); // Val layout, 1 vals per store - - using SmemCopyAtom = Copy_Atom; - using SmemCopyAtomTransposed = Copy_Atom; - // For the case where the N dimension of MmaSdP is divisible by 8 but not by 16 - using SmemCopyAtomHalf = Copy_Atom; - // For the case where the N dimension of MmadQ is divisible by 8 but not by 16 - using SmemCopyAtomTransposedHalf = Copy_Atom; - // If !SdP_swapAB, the accum registers hold P / dS, otherwise they hold Pt / dSt. - // If PdS_major is MN, then we need to "transpose" the write. - // TODO: check this write - using R2SCopyAtomPdS = Copy_Atom, Element>; - - // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading - // from the same address by the same threadblock. This is slightly faster. - using GmemCopyStruct = std::conditional_t< - Has_cp_async, - SM80_CP_ASYNC_CACHEGLOBAL_ZFILL, - AutoVectorizingCopyWithAssumedAlignment<128> - >; - using GmemCopyAtom = Copy_Atom; - - static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; - static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); - using GmemLayoutAtom = Layout, Int>, - Stride, _1>>; - using GmemTiledCopyQKV = decltype( - make_tiled_copy(GmemCopyAtom{}, - GmemLayoutAtom{}, - Layout>>{})); // Val layout, 8 or 16 vals per read - using GmemCopyAtomLSE = Copy_Atom; - using GmemLayoutAtomLSE = Layout>>; - using GmemTiledCopyLSE = decltype(make_tiled_copy(GmemCopyAtomLSE{}, GmemLayoutAtomLSE{}, - Layout>{})); // Val layout, 4 vals per store - // So that we don't have to check if we overshot kBlockM when we load Q - // static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0); - - using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) - using StrideQKV = cute::Stride; - using ShapeLSE = cute::Shape; // (seqlen, head, batch) - using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch) - using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch) - using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; - - // These are tuned for speed. They don't affect correctness. - // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64 - // this helps quite a bit to not have to do causal masking for most of the iterations. - // For hdim 192, separating masking iterations results in register spills. - // static constexpr bool SeparateMaskingIterations = kHeadDim <= 64; - static constexpr bool SeparateMaskingIterations = false; - // Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share them and then - // shuffle to get the value whenever we need? This can reduce register pressure when SdP_swapAB, where each - // thread needs to keep statistics for (kBlockM / 4) rows. If !SdP_swapAB, each thread only needs to keep - // statistic for 2 rows. - // static constexpr bool ShuffleLSE = SdP_swapAB && kHeadDim <= 64; - // static constexpr bool ShuffledPsum = SdP_swapAB && kHeadDim <= 64; - static constexpr bool ShuffleLSE = SdP_swapAB && false; - static constexpr bool ShuffledPsum = SdP_swapAB && false; - - static constexpr bool Share_QV_Smem = V_in_regs; - using SmemP_t = std::conditional_t, cute::array_aligned>>; - - struct TensorStorageSharedQV : cute::aligned_struct<128> { - cute::array_aligned> smem_k; - union { - cute::array_aligned> smem_v; - cute::array_aligned> smem_q; - }; - cute::array_aligned> smem_do; - cute::array_aligned, 128> smem_lse; - cute::array_aligned, 128> smem_dpsum; - SmemP_t smem_p; - cute::array_aligned> smem_ds; - }; - - struct TensorStorageSeparateQV : cute::aligned_struct<128> { - cute::array_aligned> smem_k; - cute::array_aligned> smem_v; - cute::array_aligned> smem_q; - cute::array_aligned> smem_do; - cute::array_aligned, 128> smem_lse; - cute::array_aligned, 128> smem_dpsum; - SmemP_t smem_p; - cute::array_aligned> smem_ds; - }; - - using TensorStorage = std::conditional_t; - - // Host side kernel arguments - struct Arguments { - Element const* const ptr_Q; - ShapeQKV const shape_Q; - StrideQKV const stride_Q; - Element const* const ptr_K; - ShapeQKV const shape_K; - StrideQKV const stride_K; - Element const* const ptr_V; - StrideQKV const stride_V; - Element const* const ptr_dO; - StrideQKV const stride_dO; - ElementAccum* const ptr_dQaccum; - ShapedQaccum const shape_dQaccum; - StridedQaccum const stride_dQaccum; - float const* const ptr_LSE_log2; - ShapeLSE const shape_LSE; - StrideLSE const stride_LSE_log2; - float const* const ptr_dPsum; - StrideLSE const stride_dPsum; - float const softmax_scale; - int const window_size_left, window_size_right; - float const softcap_val; - int const num_batch; - int* const dq_semaphore; - int const* const cu_seqlens_q = nullptr; - int const* const cu_seqlens_k = nullptr; - int const* const seqused_q = nullptr; - int const* const seqused_k = nullptr; - }; - - // Device side kernel params - struct Params { - Element const* const ptr_Q; - ShapeQKV const shape_Q; - StrideQKV const stride_Q; - Element const* const ptr_K; - ShapeQKV const shape_K; - StrideQKV const stride_K; - Element const* const ptr_V; - StrideQKV const stride_V; - Element const* const ptr_dO; - StrideQKV const stride_dO; - ElementAccum* const ptr_dQaccum; - ShapedQaccum const shape_dQaccum; - StridedQaccum stride_dQaccum; - cutlass::FastDivmod qhead_per_khead_divmod; - float const* const ptr_LSE_log2; - ShapeLSE const shape_LSE; - StrideLSE const stride_LSE_log2; - float const* const ptr_dPsum; - StrideLSE const stride_dPsum; - float const softmax_scale, softmax_scale_log2; - int const window_size_left, window_size_right; - float const softcap_val; - int const num_batch; - int *const dq_semaphore; - int const *const cu_seqlens_q = nullptr; - int const *const cu_seqlens_k = nullptr; - int const *const seqused_q = nullptr; - int const *const seqused_k = nullptr; - }; - - static Params - to_underlying_arguments(Arguments const& args) { - if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); } - // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. - // Right after this, we multiply by log2(e) before applying exp2. - // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val - // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) - // (assigning it to params.softmax_scale_log2). - // In the backward, we need to multiply by - // (1 - tanh^2) * softmax_scale / softcap_val * softcap_val = (1 - tanh^2) * softmax_scale. - // Instead we multiply by (1 - tanh^2) and multiply dK and dV by params.softmax_scale - // (the original softmax_scale) at the end. - return {args.ptr_Q, args.shape_Q, args.stride_Q, - args.ptr_K, args.shape_K, args.stride_K, - args.ptr_V, args.stride_V, - args.ptr_dO, args.stride_dO, - args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum, - cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), - args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, - args.softmax_scale, - !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), - args.window_size_left, args.window_size_right, - !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, - args.num_batch, args.dq_semaphore, - args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k}; - } - - template - CUTLASS_DEVICE bool - mma(Params const& params, - FrgTensordKV& tdKrdK, - FrgTensordKV& tdVrdV, - int thread_idx, - cute::tuple block_coord, - SharedStorage& shared_storage - ) { - static_assert(is_rmem::value, "dK and dV tensor must be rmem resident."); - - int n_block = get<0>(block_coord); - int bidh = get<1>(block_coord); - int bidb = get<2>(block_coord); - SeqlenInfo_t seqlen_info{ - bidb, get<0>(params.shape_Q), size<0>(params.shape_K), - params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k - }; - auto m_block_min_max = BlockMN_t::get_m_block_min_max( - seqlen_info, n_block, bidb, - params.window_size_left, params.window_size_right, 0 /*sink_token_length*/); - int const m_block_min = get<0>(m_block_min_max); - int const m_block_max = get<1>(m_block_min_max); - // It's possible to have m_block_max <= m_block_min. Exit early - if constexpr (Is_causal || Is_local || Varlen) { - if (m_block_max <= m_block_min) { return false; } - } - - Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); - Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{}); - Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); - Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); - Tensor sQt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQt{}); - Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdOt{}); - Tensor sKt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutKt{}); - Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdS{}); - Tensor sPt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdSt{}); - Tensor sdS = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdS{}); - Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdSt{}); - Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSE{}); - Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSE{}); - Tensor sLSEMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSEMma{}); - Tensor sdPsumMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSEMma{}); - - bool const is_varlen_q = Varlen && params.cu_seqlens_q; - bool const is_varlen_k = Varlen && params.cu_seqlens_k; - int bidh_kv = params.qhead_per_khead_divmod.divide(bidh); - Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q), params.shape_Q, params.stride_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_Q, params.stride_dO)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); - Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); - Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, bidh, !is_varlen_q ? bidb : 0); - Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, bidh, !is_varlen_q ? bidb : 0); - Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), - params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen_q ? bidb : 0); - - Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) - Tensor gdO = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) - Tensor gK = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) - Tensor gV = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) - Tensor gLSE = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mLSE), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _) - Tensor gdPsum = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mdPsum), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _) - Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(_)); // (M * K, _) - - GmemTiledCopyQKV gmem_tiled_copy_QKV; - auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(thread_idx); - auto gmem_thr0_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(_0{}); // For index calculation - GmemTiledCopyLSE gmem_tiled_copy_lse; - auto gmem_thr_copy_lse = gmem_tiled_copy_lse.get_thread_slice(thread_idx); - R2STiledCopydQaccum r2s_tiled_copy_dQaccum; - auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(thread_idx); - - Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); - Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tdOgdO = gmem_thr_copy_QKV.partition_S(gdO); - Tensor tdOsdO = gmem_thr_copy_QKV.partition_D(sdO); - Tensor tLSEgLSE = gmem_thr_copy_lse.partition_S(gLSE); - Tensor tLSEsLSE = gmem_thr_copy_lse.partition_D(sLSE); - Tensor tLSEgdPsum = gmem_thr_copy_lse.partition_S(gdPsum); - Tensor tLSEsdPsum = gmem_thr_copy_lse.partition_D(sdPsum); - // We can reuse r2s_thr_copy_dQaccum for this partitioning - Tensor tdQgdQaccum = r2s_thr_copy_dQaccum.partition_D(gdQaccum); - // if (blockIdx.x == 0 && threadIdx.x == 128) { print(mdQaccum); printf("\n"); print(gdQaccum_); printf("\n"); print(gdQaccum); printf("\n"); print(tdQgdQaccum); printf("\n"); } - - TiledMmaSdP tiled_mma_SdP; - TiledMmadKV tiled_mma_dKV; - TiledMmadQ tiled_mma_dQ; - - auto thr_mma_SdP = tiled_mma_SdP.get_thread_slice(thread_idx); - auto thr_mma_dKV = tiled_mma_dKV.get_thread_slice(thread_idx); - auto thr_mma_dQ = tiled_mma_dQ.get_thread_slice(thread_idx); - - // Allocate "fragments/descriptors" - // We have to use the templated mma_partition_fragment_AB instead of cute::conditional_return or lambda, - // because some partition_fragment_A/B don't compile. - // https://stackoverflow.com/questions/50051473/if-constexpr-in-c17-does-not-work-in-a-non-templated-function - Tensor tdPrV = mma_partition_fragment_AB(thr_mma_SdP, sV); - - // Copy Atom retiling - auto smem_copy_atom_SdP_B = cute::conditional_return(SmemCopyAtom{}, SmemCopyAtomHalf{}); - auto smem_tiled_copy_QdO = cute::conditional_return(make_tiled_copy_A(SmemCopyAtom{}, tiled_mma_SdP), make_tiled_copy_B(smem_copy_atom_SdP_B, tiled_mma_SdP)); - auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(thread_idx); - Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ); - Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO); - - auto smem_tiled_copy_KV = cute::conditional_return(make_tiled_copy_B(smem_copy_atom_SdP_B, tiled_mma_SdP), make_tiled_copy_A(SmemCopyAtom{}, tiled_mma_SdP)); - auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(thread_idx); - Tensor tSsK = smem_thr_copy_KV.partition_S(sK); - Tensor tdPsV = smem_thr_copy_KV.partition_S(sV); - - auto r2s_tiled_copy_PdS = make_tiled_copy_C(R2SCopyAtomPdS{}, tiled_mma_SdP); - auto r2s_thr_copy_PdS = r2s_tiled_copy_PdS.get_thread_slice(thread_idx); - Tensor tPsP = r2s_thr_copy_PdS.partition_D(cute::conditional_return(sP, sPt)); // ((Atom,AtomNum),PIPE_M,PIPE_N) - Tensor tdSsdS = r2s_thr_copy_PdS.partition_D(cute::conditional_return(sdS, sdSt)); // ((Atom,AtomNum),PIPE_M,PIPE_N) - // if (blockIdx.x == 0 && threadIdx.x == 128) { print(r2s_thr_copy_PdS); print(sP); printf("\n"); print(sPt); printf("\n"); print(tPsP); printf("\n"); print(tdSsdS); printf("\n"); } - - auto smem_copy_atom_dKV_B = cute::conditional_return(SmemCopyAtomTransposed{}, SmemCopyAtomTransposedHalf{}); - auto smem_tiled_copy_PdSt = cute::conditional_return(make_tiled_copy_A(SmemCopyAtomTransposed{}, tiled_mma_dKV), make_tiled_copy_B(smem_copy_atom_dKV_B, tiled_mma_dKV)); - auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(thread_idx); - Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt); - Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt); - - auto smem_tiled_copy_QdOt = cute::conditional_return(make_tiled_copy_B(smem_copy_atom_dKV_B, tiled_mma_dKV), make_tiled_copy_A(SmemCopyAtomTransposed{}, tiled_mma_dKV)); - auto smem_thr_copy_QdOt = smem_tiled_copy_QdOt.get_thread_slice(thread_idx); - Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt); - Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt); - - auto smem_tiled_copy_dS = cute::conditional_return( - make_tiled_copy_A(SmemCopyAtom{}, tiled_mma_dQ), - make_tiled_copy_B(cute::conditional_return(SmemCopyAtom{}, SmemCopyAtomHalf{}), tiled_mma_dQ)); - auto smem_thr_copy_dS = smem_tiled_copy_dS.get_thread_slice(thread_idx); - Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS); - - auto smem_tiled_copy_Kt = cute::conditional_return( - make_tiled_copy_B(cute::conditional_return(SmemCopyAtomTransposed{}, SmemCopyAtomTransposedHalf{}), tiled_mma_dQ), - make_tiled_copy_A(SmemCopyAtomTransposed{}, tiled_mma_dQ)); - auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(thread_idx); - Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt); - - // thr_mma_SdP.partition_C(sLSEMma) has shape (MMA=4, MMA_M, MMA_N, PIPE), we only take the col indices - // or row indices, depending on whether SdP_swapAB. - Tensor tSsLSEMma = logical_divide(thr_mma_SdP.partition_C(sLSEMma), Shape<_2>{}); // (2, 2, MMA_M, MMA_N, PIPE) - Tensor tSsLSE = group_modes<0, 2>(cute::conditional_return( - tSsLSEMma(make_coord(_0{}, _), _, _0{}, _), // (2, MMA_M, PIPE) - tSsLSEMma(make_coord(_, _0{}), _0{}, _, _))); // (2, MMA_N, PIPE) - Tensor tSsdPsumMma = logical_divide(thr_mma_SdP.partition_C(sdPsumMma), Shape<_2>{}); - Tensor tSsdPsum = group_modes<0, 2>(cute::conditional_return( - tSsdPsumMma(make_coord(_0{}, _), _, _0{}, _), // (2, MMA_M, PIPE) - tSsdPsumMma(make_coord(_, _0{}), _0{}, _, _))); // (2, MMA_N, PIPE) - // if (blockIdx.x == 0 && threadIdx.x == 128) { print(sLSEMma); printf("\n"); print(tLSEsLSE); printf("\n"); } - // If we want to split the stats among the 8 threads that share the same rows. - static constexpr int kStatsPerThread = cute::ceil_div(decltype(size(tSsLSE))::value, 8); - - // Predicates - Tensor cQ = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); - Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); - Tensor t0QcQ = gmem_thr0_copy_QKV.partition_S(cQ); - Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); - #pragma unroll - for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_Q); } - Tensor cLSE = cute::make_identity_tensor(select<0>(TileShape_MNK{})); - Tensor tLSEcLSE = gmem_thr_copy_lse.partition_S(cLSE); - - int const seqlen_q = seqlen_info.seqlen_q; - int const seqlen_k = seqlen_info.seqlen_k; - - flash::Mask mask( - thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, - params.qhead_per_khead_divmod - ); - - { - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) - Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) - Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - // Predicates - Tensor cKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); - Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); - Tensor t0KVcKV = gmem_thr0_copy_QKV.partition_S(cKV); - Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); - #pragma unroll - for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_K); } - // Do we need bound check to make sure the row doesn't go above kBlockN - static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; - // static_assert(EvenN); // It simplifies the loading of K and V - // Instead of passing in tKVcKV, we pass in t0KVcKV and subtract the offset from the limit - // (seqlen_k - n_block * kBlockN). This is because the entries of t0KVcKV are known at compile time. - // int const seqlenk_row_limit = -int(get<0>(tKVcKV(_0{}, _0{}, _0{}))) + (EvenN - // ? seqlen_info.seqlen_k - n_block * kBlockN - // : std::min(seqlen_info.seqlen_k - n_block * kBlockN, kBlockN)); - // // Need Clear_OOB_MN to be true here since the gemm will sum over the kBlockN dimension - // flash::copy( - // gmem_tiled_copy_QKV, tVgV, tVsV, t0KVcKV, tKVpKV, seqlenk_row_limit); - int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tKVcKV(_0{}, _0{}, _0{})); - #pragma unroll - for (int m = 0; m < size<1>(tVsV); ++m) { - // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to be checked - if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tKVcKV(_0{}, m, _0{})) < kBlockN) { - bool const predicate_n = get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit; - #pragma unroll - for (int k = 0; k < size<2>(tVsV); ++k) { - cute::copy(gmem_tiled_copy_QKV.with(tKVpKV(k) && predicate_n), tVgV(_, m, k), tVsV(_, m, k)); - } - } - } - if constexpr (V_in_regs) { flash::cp_async_fence(); } - // flash::copy( - // gmem_tiled_copy_QKV, tKgK, tKsK, t0KVcKV, tKVpKV, seqlenk_row_limit); - #pragma unroll - for (int m = 0; m < size<1>(tKsK); ++m) { - if (EvenN || m < size<1>(tKsK) - 1 || get<0>(tKVcKV(_0{}, m, _0{})) < kBlockN) { - bool const predicate_n = get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit; - #pragma unroll - for (int k = 0; k < size<2>(tKsK); ++k) { - cute::copy(gmem_tiled_copy_QKV.with(tKVpKV(k) && predicate_n), tKgK(_, m, k), tKsK(_, m, k)); - } - } - } - flash::cp_async_fence(); - } - - if constexpr (V_in_regs) { - flash::cp_async_wait<1>(); - __syncthreads(); - Tensor tdPrV_copy_view = smem_thr_copy_KV.retile_D(tdPrV); - Tensor tdPsV_copy_view = smem_thr_copy_KV.partition_S(sV); - cute::copy(smem_tiled_copy_KV, tdPsV_copy_view, tdPrV_copy_view); - __syncthreads(); // Sync to avoid loading Q to smem_q, which overlaps with smem_v - } - - // Do we need bound check to make sure the row doesn't go above kBlockM - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr bool EvenM = kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; - - auto load_Q_LSE = [&] (int const m_block, int const smem_pipe_write) { - // if (cute::thread0()) { printf("Inside load_Q_LSE, m_block = %d, smem_pipe_write = %d\n", m_block, smem_pipe_write); } - Tensor tQsQ_cur = tQsQ(_, _, _, smem_pipe_write); - Tensor tQgQ_cur = tQgQ(_, _, _, m_block); - // Instead of passing in tQcQ, we pass in t0QcQ and subtract the offset from the limit - // (seqlen_q - m_block * kBlockM). This is because the entries of t0QcQ are known at compile time. - // int const seqlenq_row_limit = -int(get<0>(tQcQ(_0{}, _0{}, _0{}))) + (EvenM - // ? seqlen_info.seqlen_q - m_block * kBlockM - // : std::min(seqlen_info.seqlen_q - m_block * kBlockM, kBlockM)); - // Need Clear_OOB_MN to be true here since the gemm will sum over the kBlockM dimension - // flash::copy( - // gmem_tiled_copy_QKV, tQgQ(_, _, _, m_block), tQsQ_cur, t0QcQ, tQpQ, seqlenq_row_limit); - int const seqlenq_row_limit = seqlen_info.seqlen_q - m_block * kBlockM - get<0>(tQcQ(_0{}, _0{}, _0{})); - #pragma unroll - for (int m = 0; m < size<1>(tQsQ); ++m) { - // If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked - if (EvenM || m < size<1>(tQsQ) - 1 || get<0>(tQcQ(_0{}, m, _0{})) < kBlockM) { - bool const predicate_m = get<0>(t0QcQ(_0{}, m, _0{})) < seqlenq_row_limit; - #pragma unroll - for (int k = 0; k < size<2>(tQsQ); ++k) { - cute::copy(gmem_tiled_copy_QKV.with(tQpQ(k) && predicate_m), tQgQ_cur(_, m, k), tQsQ_cur(_, m, k)); - } - } - } - Tensor tLSEgLSE_cur = tLSEgLSE(_, _, m_block); - Tensor tLSEsLSE_cur = tLSEsLSE(_, _, smem_pipe_write); - // We made sure LSE length is padded so we read `kBlockM` elements so that all - // elements in sLSE are filled. Without this we might have uninitialized sLSE values. - #pragma unroll - for (int m = 0; m < size<1>(tLSEsLSE); ++m) { - if (get<0>(tLSEcLSE(_0{}, m)) < kBlockM) { - cute::copy(gmem_tiled_copy_lse, tLSEgLSE_cur(_, m), tLSEsLSE_cur(_, m)); - } - } - }; - - auto load_dO_dPsum = [&] (int const m_block, int const smem_pipe_write) { - // if (cute::thread0()) { printf("Inside load_dO_dPsum, m_block = %d, smem_pipe_write = %d\n", m_block, smem_pipe_write); } - Tensor tdOsdO_cur = tdOsdO(_, _, _, smem_pipe_write); - Tensor tdOgdO_cur = tdOgdO(_, _, _, m_block); - // int const seqlenq_row_limit = -int(get<0>(tQcQ(_0{}, _0{}, _0{}))) + (EvenM - // ? seqlen_info.seqlen_q - m_block * kBlockM - // : std::min(seqlen_info.seqlen_q - m_block * kBlockM, kBlockM)); - // flash::copy( - // gmem_tiled_copy_QKV, tdOgdO(_, _, _, m_block), tdOsdO_cur, t0QcQ, tQpQ, seqlenq_row_limit); - int const seqlenq_row_limit = seqlen_info.seqlen_q - m_block * kBlockM - get<0>(tQcQ(_0{}, _0{}, _0{})); - #pragma unroll - for (int m = 0; m < size<1>(tdOsdO); ++m) { - // If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked - if (EvenM || m < size<1>(tdOsdO) - 1 || get<0>(tQcQ(_0{}, m, _0{})) < kBlockM) { - bool const predicate_m = get<0>(t0QcQ(_0{}, m, _0{})) < seqlenq_row_limit; - #pragma unroll - for (int k = 0; k < size<2>(tdOsdO); ++k) { - cute::copy(gmem_tiled_copy_QKV.with(tQpQ(k) && predicate_m), tdOgdO_cur(_, m, k), tdOsdO_cur(_, m, k)); - } - } - } - Tensor tLSEgdPsum_cur = tLSEgdPsum(_, _, m_block); - Tensor tLSEsdPsum_cur = tLSEsdPsum(_, _, smem_pipe_write); - #pragma unroll - for (int m = 0; m < size<1>(tLSEsdPsum); ++m) { - if (get<0>(tLSEcLSE(_0{}, m)) < kBlockM) { - cute::copy(gmem_tiled_copy_lse, tLSEgdPsum_cur(_, m), tLSEsdPsum_cur(_, m)); - } - } - }; - - int m_block = m_block_min; - - // Note, using the for_each() function here to ensure `stage` is of type Int. - for_each(make_int_sequence{}, [&] (auto stage) { - static constexpr bool Is_first_stage = CUTE_STATIC_V(stage) == 0; - static constexpr bool Is_last_stage = CUTE_STATIC_V(stage) == kStages - 1; - if constexpr (!Is_last_stage || kStages == 1) { - if (Is_first_stage || m_block + stage < m_block_max) { - load_Q_LSE(m_block + stage, stage); - } - } - // We want the fence outside the if statement to have a fixed number of cp.async commits. - // so that we can wait with the correct number of outstanding commits. - cute::cp_async_fence(); - if constexpr (stage < kStages_dO) { - if (Is_first_stage || m_block + stage < m_block_max) { - load_dO_dPsum(m_block + stage, stage); - } - cute::cp_async_fence(); - } - }); - - int smem_pipe_read = 0, smem_pipe_read_do = 0, smem_pipe_write = kStages - 1, smem_pipe_write_do = 0; - - auto load_Q_next = [&] { - // if (cute::thread0()) { printf("m_block = %d, m_block_max = %d, smem_pipe_write = %d\n", m_block, m_block_max, smem_pipe_write); } - if (m_block + (kStages > 1 ? kStages - 1 : 1) < m_block_max) { - load_Q_LSE(m_block + (kStages > 1 ? kStages - 1 : 1), kStages > 1 ? smem_pipe_write : 0); - } - cute::cp_async_fence(); - }; - - auto load_dO_next = [&] { - // int smem_pipe_write_do_cur = Q_dO_same_stages ? smem_pipe_write : smem_pipe_write_do; - if (m_block + kStages_dO < m_block_max) { - // load_dO_dPsum(m_block + kStages_dO, kStages_dO > 1 ? smem_pipe_write_do_cur : 0); - load_dO_dPsum(m_block + kStages_dO, kStages_dO > 1 ? smem_pipe_write_do : 0); - } - cute::cp_async_fence(); - }; - - clear(tdKrdK); - clear(tdVrdV); - - auto bwd_step = [&](int m_block, auto mask_fn) { - Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select(TileShape_MNK{})); - clear(tSrS); - flash::cp_async_wait<(kStages > 1) ? 1 : 0>(); - __syncthreads(); - Tensor tSrQ = mma_partition_fragment_AB(thr_mma_SdP, sQ(_, _, _0{})); - Tensor tSrK = mma_partition_fragment_AB(thr_mma_SdP, sK); - // if (cute::thread0()) { print(tiled_mma_SdP); print(tSrS); printf("\n"); print(tSrQ); printf("\n"); print(tSrK); printf("\n"); print(tSsQ); printf("\n"); print(tSsK); printf("\n"); } - flash::gemm_sm80( - tSrS, tSrQ, tSrK, tSsQ(_, _, _, kStages > 1 ? smem_pipe_read : 0), tSsK, - tiled_mma_SdP, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV, nullptr /*hook*/); - Tensor tLSErLSE = cute::conditional_return(make_fragment_like(tSsLSE(_, _0{})), make_tensor(Int{})); - if constexpr (!ShuffleLSE) { - cute::copy(tSsLSE(_, kStages > 1 ? smem_pipe_read : 0), tLSErLSE); - } else { - #pragma unroll - for (int i = 0; i < kStatsPerThread; ++i) { - // It's ok to read OOB, since we made sure sLSE is large enough and we won't use the OOB values - tLSErLSE(i) = tSsLSE((thread_idx % 32) / 4 + i * 8, kStages > 1 ? smem_pipe_read : 0); - } - } - if constexpr (Has_softcap) { flash::apply_softcap(tSrS, params.softcap_val); } - - // Reshape tSrS from (4, MMA_N, MMA_M) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout())); - // dtanh needs to happen before masking, otherwise we get 1 - (-inf)^2 = NaN in the dtanh - // if (cute::thread0()) { print_tensor(scores); } - auto dtanh = [&] { if constexpr (Has_softcap) return flash::calculate_dtanh(scores); else return nullptr; }(); - mask_fn(tSrS, m_block); - #pragma unroll - for (int mi = 0; mi < size<0>(scores); ++mi) { - float const lse_scaled = [&] { - if constexpr (!ShuffleLSE) return tLSErLSE(mi); - else return __shfl_sync(0xffffffff, tLSErLSE(mi / 8), (mi % 8) * 4 + (thread_idx % 4)); - }(); - #pragma unroll - for (int ni = 0; ni < size<1>(scores); ++ni) { - scores(mi, ni) = exp2f(scores(mi, ni) * params.softmax_scale_log2 - lse_scaled); - } - } - - Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select(TileShape_MNK{})); - clear(tdPrdP); - int smem_pipe_read_do_cur = Q_dO_same_stages ? smem_pipe_read : smem_pipe_read_do; - flash::cp_async_wait<(kStages_dO > 1) ? 1 : 0>(); - __syncthreads(); - auto hook = cute::conditional_return<(kStages > 1)>(load_Q_next, nullptr); - Tensor tdPrdO = mma_partition_fragment_AB(thr_mma_SdP, sdO(_, _, _0{})); - Tensor tdPrV_cur = cute::conditional_return(tdPrV, mma_partition_fragment_AB(thr_mma_SdP, sV)); - flash::gemm_sm80( - tdPrdP, tdPrdO, tdPrV_cur, tdPsdO(_, _, _, kStages_dO > 1 ? smem_pipe_read_do_cur : 0), tdPsV, - tiled_mma_SdP, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV, hook); - Tensor tLSErdPsum = cute::conditional_return(make_fragment_like(tSsdPsum(_, _0{})), make_tensor(Int{})); - if constexpr (!ShuffledPsum) { - cute::copy(tSsdPsum(_, kStages_dO > 1 ? smem_pipe_read_do_cur : 0), tLSErdPsum); - } else { - #pragma unroll - for (int i = 0; i < kStatsPerThread; ++i) { - tLSErdPsum(i) = tSsdPsum((thread_idx % 32) / 4 + i * 8, kStages_dO > 1 ? smem_pipe_read_do_cur : 0); - } - } - - // Reshape tdPrdP from (4, MMA_N, MMA_M) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor dS = make_tensor(tdPrdP.data(), scores.layout()); - #pragma unroll - for (int mi = 0; mi < size<0>(dS); ++mi) { - float const dP_sum_cur = [&] { - if constexpr (!ShuffledPsum) return tLSErdPsum(mi); - else return __shfl_sync(0xffffffff, tLSErdPsum(mi / 8), (mi % 8) * 4 + (thread_idx % 4)); - }(); - #pragma unroll - for (int ni = 0; ni < size<1>(dS); ++ni) { - dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum_cur); - if constexpr (Has_softcap) { dS(mi, ni) *= dtanh(mi, ni); } - } - } - // if (cute::thread0()) { print_tensor(dS); } - - // Convert scores from fp32 to fp16/bf16 - Tensor rP = make_tensor_like(tSrS); - flash::convert_type_out(tSrS, rP); - if constexpr (!Mma_dKV_is_RS) { - Tensor tPaP = r2s_thr_copy_PdS.retile_S(rP); // ((Atom,AtomNum), MMA_N, MMA_N) - cute::copy(r2s_tiled_copy_PdS, tPaP, tPsP); - } - Tensor rdS = make_tensor_like(tdPrdP); - flash::convert_type_out(tdPrdP, rdS); - if constexpr (!Mma_dKV_is_RS) { __syncthreads(); } // Make sure P is written - // For hdim 64, It's faster to write to smem_dS first before the dV gemm - Tensor tdSadS = r2s_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N) - cute::copy(r2s_tiled_copy_PdS, tdSadS, tdSsdS); - - Tensor tdVrdO = mma_partition_fragment_AB(thr_mma_dKV, sdOt(_, _, _0{})); - Tensor tdVsdO_cur = tdVsdOt(_, _, _, kStages_dO > 1 ? smem_pipe_read_do_cur : 0); - if constexpr (Mma_dKV_is_RS) { - Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs(tSrS.layout())); - flash::gemm_rs_sm80(tdVrdV, tdVrP, tdVrdO, tdVsdO_cur, tiled_mma_dKV, smem_tiled_copy_QdOt, smem_thr_copy_QdOt); - } else { - Tensor tdVrP = mma_partition_fragment_AB(thr_mma_dKV, sPt); - flash::gemm_sm80( - tdVrdV, tdVrP, tdVrdO, tdVsPt, tdVsdO_cur, - tiled_mma_dKV, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt, nullptr); - } - // if (cute::thread0()) { print_tensor(tdVrdV); } - __syncthreads(); // make sure sdS is written - auto do_mma_dQ = [&] (auto hook) { - Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{})); - clear(tdQrdQ); - Tensor tdQrdS = mma_partition_fragment_AB(thr_mma_dQ, sdS); - Tensor tdQrK = mma_partition_fragment_AB(thr_mma_dQ, sKt); - flash::gemm_sm80( - tdQrdQ, tdQrdS, tdQrK, tdQsdS, tdQsKt, tiled_mma_dQ, - // smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt, load_dO_next); - smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt, hook); - // if (cute::thread0()) { print_tensor(tdQrdQ); } - // We can reuse r2s_thr_copy_dQaccum for this partitioning - Tensor tdQrdQ_atomic = r2s_thr_copy_dQaccum.retile_S(tdQrdQ); - Tensor tdQgdQaccum_atomic = tdQgdQaccum(_, _, m_block); - static_assert(CUTE_STATIC_V(size(tdQrdQ_atomic)) == CUTE_STATIC_V(size(tdQgdQaccum_atomic))); - #pragma unroll - for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } - }; - // If kStages == 1, we want to do Mma_dK first so we can start loading Q for the next iteration - if constexpr (kStages > 1) { do_mma_dQ(load_dO_next); } - Tensor tdKrQ = mma_partition_fragment_AB(thr_mma_dKV, sQt(_, _, _0{})); - Tensor tdKsQ_cur = tdKsQt(_, _, _, kStages > 1 ? smem_pipe_read : 0); - if constexpr (Mma_dKV_is_RS) { - Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); - flash::gemm_rs_sm80(tdKrdK, tdKrdS, tdKrQ, tdKsQ_cur, tiled_mma_dKV, smem_tiled_copy_QdOt, smem_thr_copy_QdOt); - } else { - Tensor tdKrdS = mma_partition_fragment_AB(thr_mma_dKV, sdSt); - flash::gemm_sm80( - tdKrdK, tdKrdS, tdKrQ, tdKsdSt, tdKsQ_cur, - tiled_mma_dKV, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt, cute::conditional_return<(kStages > 1)>(nullptr, load_dO_next)); - } - if constexpr (kStages == 1) { - __syncthreads(); - do_mma_dQ(load_Q_next); - } - // if (cute::thread0()) { print_tensor(tdKrdK); } - - smem_pipe_read = smem_pipe_read < kStages - 1 ? smem_pipe_read + 1 : 0; - smem_pipe_read_do = smem_pipe_read_do < kStages_dO - 1 ? smem_pipe_read_do + 1 : 0; - smem_pipe_write = smem_pipe_write < kStages - 1 ? smem_pipe_write + 1 : 0; - smem_pipe_write_do = smem_pipe_write_do < kStages_dO - 1 ? smem_pipe_write_do + 1 : 0; - - }; - - // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64 - // this helps quite a bit to not have to do causal masking for most of the iterations. - if constexpr ((Is_causal || Is_local) && SeparateMaskingIterations) { - auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; - int const m_block_masking_max = ((n_block + 1) * kBlockN - 1 + seqlen_q - seqlen_k - params.window_size_right) / kBlockM + 1; - CUTLASS_PRAGMA_NO_UNROLL - for (; m_block < std::min(m_block_max, m_block_masking_max); ++m_block) { - bwd_step(m_block, mask_fn); - } - } - - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - int const m_block_max_before_local_mask = !Is_local || !SeparateMaskingIterations - ? m_block_max - : std::min(m_block_max, (n_block * kBlockN + seqlen_q - seqlen_k + params.window_size_left) / kBlockM); - - auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; - CUTLASS_PRAGMA_NO_UNROLL - for (; m_block < m_block_max_before_local_mask; ++m_block) { - bwd_step(m_block, mask_fn); - } - - if constexpr (Is_local && SeparateMaskingIterations) { - auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; - CUTLASS_PRAGMA_NO_UNROLL - for (; m_block < m_block_max; ++m_block) { - bwd_step(m_block, mask_fn); - } - } - - // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tdVrdV); } - #pragma unroll - for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.softmax_scale; } - - return true; - } - -}; - -} // namespace flash diff --git a/flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp b/flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp deleted file mode 100644 index 71cfb02046952f6aa97e3cc58c9a2d6fbd748158..0000000000000000000000000000000000000000 --- a/flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ /dev/null @@ -1,1029 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include -#include -#include -#include -#include -#include "cutlass/pipeline/pipeline.hpp" - -#include "cute/tensor.hpp" - -#include "cutlass/gemm/collective/builders/sm90_common.inl" - -#include "named_barrier.hpp" -#include "seqlen.h" -#include "block.h" -#include "mask.h" -#include "softmax.h" -#include "utils.h" -#include "copy_sm90_bulk_reduce.hpp" - -namespace flash { - -using namespace cute; - -template -struct CollectiveMainloopBwdSm90 { - - static constexpr int kStages = Stages; - static constexpr int kStages_dO = Stages_dO; - static constexpr int kStages_dS = Stages_dS; - static_assert(kStages >= kStages_dO); - static_assert(Stages_dS == 1 || Stages_dS == kStages); - static_assert(!Mma_dP_is_RS || SdP_swapAB_); // If Mma_dP_is_RS, we need SdP_SwapAB - using ClusterShape = ClusterShape_; - using TileShape_MNK = TileShape_MNK_; - using Element = Element_; - using ElementAccum = ElementAccum_; - using ArchTag = ArchTag_; - static constexpr bool Is_causal = Is_causal_; - static constexpr bool Is_local = Is_local_; - static constexpr bool Has_softcap = Has_softcap_; - static constexpr bool Varlen = Varlen_; - - static constexpr bool SdP_swapAB = SdP_swapAB_; - static constexpr bool dKV_swapAB = dKV_swapAB_; - static constexpr bool dQ_swapAB = dQ_swapAB_; - - static constexpr bool Q_dO_same_stages = kStages == kStages_dO; - - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - static constexpr int kHeadDim = get<2>(TileShape_MNK{}); - - using SeqlenInfo_t = flash::SeqlenInfoQK; - using BlockMN_t = flash::BlockMN; - - static_assert(ArchTag::kMinComputeCapability >= 90); - static_assert(get<0>(ClusterShape{}) == 1 && get<2>(ClusterShape{}) == 1); - - static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; - static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarp * 2; - - static_assert(NumMmaWarpGroups % AtomLayoutMSdP == 0); - static_assert(NumMmaWarpGroups % AtomLayoutNdKV == 0); - static_assert(NumMmaWarpGroups % AtomLayoutMdQ == 0); - static constexpr bool Mma_dKV_is_RS = AtomLayoutMSdP == 1 && AtomLayoutNdKV == NumMmaWarpGroups && SdP_swapAB && !dKV_swapAB; - static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == NumMmaWarpGroups && AtomLayoutMdQ == NumMmaWarpGroups && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS - - static constexpr GMMA::Major PdS_Major = GMMA::Major::K; - // static constexpr GMMA::Major PdS_Major = GMMA::Major::MN; - static constexpr GMMA::Major PdSt_Major = PdS_Major == GMMA::Major::K ? GMMA::Major::MN : GMMA::Major::K; - - using TileShapeAtomSdP = std::conditional_t< - !SdP_swapAB, - Shape, Int, Int>, - Shape, Int, Int> - >; - using AtomLayoutSdP = std::conditional_t< - !SdP_swapAB, - Layout, Int, _1>>, - Layout, Int, _1>> - >; - using TiledMmaSdP = decltype(cute::make_tiled_mma( - cute::GMMA::ss_op_selector(), - AtomLayoutSdP{})); - - using TiledMmadPRS = decltype(cute::make_tiled_mma( - cute::GMMA::rs_op_selector(), - AtomLayoutSdP{})); - - using TileShapeAtomdKV = std::conditional_t< - !dKV_swapAB, - Shape, Int, Int>, - Shape, Int, Int> - >; - using AtomLayoutdKV = std::conditional_t< - !dKV_swapAB, - Layout, Int, _1>>, - Layout, Int, _1>> - >; - using TiledMmadKV = decltype(cute::make_tiled_mma( - std::conditional_t< - Mma_dKV_is_RS, - decltype(cute::GMMA::rs_op_selector()), - decltype(cute::GMMA::ss_op_selector()) - >{}, - AtomLayoutdKV{})); - - using TileShapeAtomdQ = std::conditional_t< - !dQ_swapAB, - Shape, Int, Int>, - Shape, Int, Int> - >; - using AtomLayoutdQ = std::conditional_t< - !dQ_swapAB, - Layout, Int, _1>>, - Layout, Int, _1>> - >; - using TiledMmadQ = decltype(cute::make_tiled_mma( - std::conditional_t< - Mma_dQ_is_RS, - decltype(cute::GMMA::rs_op_selector()), - decltype(cute::GMMA::ss_op_selector()) - >{}, - AtomLayoutdQ{})); - - // We need to accommodate both Q and Q^T (and dO and dO^T) in shared memory. - // Q & dO are used in the SdP Mma and Q^T and dO^T are used in the dKV Mma. - // Since this is GMMA::Major::K, the M dimension (kBlockM) doesn't matter for the layout, only the K dimension - // changes the layout. - using SmemLayoutAtomQdO = decltype(cutlass::gemm::collective::detail::ss_smem_selector, Int>()); // for dKV_Mma - using SmemLayoutQ = - decltype(tile_to_shape(SmemLayoutAtomQdO{}, - make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); - using SmemLayoutdO = - decltype(tile_to_shape(SmemLayoutAtomQdO{}, - make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); - - using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector, Int>()); - using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{}))); - - using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{}))); - - using SmemLayoutAtomPdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector, - Int>()); - using SmemLayoutPdS = decltype(tile_to_shape( - SmemLayoutAtomPdS{}, - make_shape(Int{}, Int{}, Int{}), - std::conditional_t, cute::Step<_2, _1, _3>>{})); - - // Need stride to be multiple of 32, otherwise we get error (misaligned address) when doing TMA if e.g. kBlockM=80 - // We set stride to be multiple of 64 so that if ShuffleLSE, even if threads read from sLSE but out of bounds, - // it's still a valid smem address. - using SmemLayoutLSE = cute::Layout, Int>, cute::Stride<_1, Int>>; - using SmemLayoutLSEMma = std::conditional_t< - SdP_swapAB, - cute::Layout, Int, Int>, cute::Stride<_0, _1, Int>>, - cute::Layout, Int, Int>, cute::Stride<_1, _0, Int>> - >; - - // Note this is the transpose in terms of the view, not in terms of memory. - using SmemLayoutQt = - decltype(cute::composition(SmemLayoutQ{}, - make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), - make_stride(Int{}, _1{}, Int{})))); - using SmemLayoutdOt = - decltype(cute::composition(SmemLayoutdO{}, - make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), - make_stride(Int{}, _1{}, Int{})))); - using SmemLayoutKt = - decltype(cute::composition(SmemLayoutK{}, - make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), - make_stride(Int{}, _1{})))); - using SmemLayoutPdSt = - decltype(cute::composition(SmemLayoutPdS{}, - make_layout(make_shape(Int{}, Int{}, Int{}), - make_stride(Int{}, _1{}, Int{})))); - - // Thread layout, 256 or 384 threads per row - // We split into NumMmaWarpGroups so that we can do Bulk reduce add for each WG separately. - using R2SLayoutAtomdQaccum = Layout, Int>>; - using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2SLayoutAtomdQaccum{}, - Layout>{})); // Val layout, 4 vals per store - using SmemLayoutdQaccum = Layout, Int>>; - - static constexpr int kNumPdSStore = kBlockM * kBlockN / NumMmaThreads; - // If !SdP_swapAB, the accum registers hold P / dS, otherwise they hold Pt / dSt. - // If PdS_major is MN, then we need to "transpose" the write. - using SmemCopyAtomPdS = Copy_Atom< - std::conditional_t<(!SdP_swapAB) ^ (PdS_Major == GMMA::Major::MN), - std::conditional_t, - std::conditional_t - >, - Element - >; - - using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape{}))); - using GmemTiledCopyKV = cute::SM90_TMA_LOAD; - - using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) - using StrideQKV = cute::Stride; - using ShapeLSE = cute::Shape; // (seqlen, head, batch) - using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch) - using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch) - using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; - - using TMA_QdO = decltype(make_tma_copy_A_sm90( - GmemTiledCopyQdO{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), - take<0, 2>(SmemLayoutQ{}), - TileShape_MNK{}, - ClusterShape{})); // mcast along N mode for this M load, if any - - using TMA_K = decltype(make_tma_copy_B_sm90( - GmemTiledCopyKV{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), - SmemLayoutK{}, - TileShape_MNK{}, - ClusterShape{})); // no mcast for KV - - using TMA_V = decltype(make_tma_copy_B_sm90( - GmemTiledCopyKV{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), - SmemLayoutV{}, - TileShape_MNK{}, - ClusterShape{})); // no mcast for KV - - using MainloopPipeline = typename cutlass::PipelineTmaAsync; - using PipelineState = typename MainloopPipeline::PipelineState; - using MainloopPipeline_dO = typename cutlass::PipelineTmaAsync; - using PipelineState_dO = typename MainloopPipeline_dO::PipelineState; - - // Set the bytes transferred in this TMA transaction (may involve multiple issues) - static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(take<0, 2>(SmemLayoutQ{})) * cutlass::sizeof_bits_v / 8); - static constexpr uint32_t TmaTransactionBytesK = static_cast(size(SmemLayoutK{}) * cutlass::sizeof_bits_v / 8); - static constexpr uint32_t TmaTransactionBytesV = static_cast(size(SmemLayoutV{}) * cutlass::sizeof_bits_v / 8); - static constexpr uint32_t TmaTransactionBytesLSE = static_cast(size(select<0>(SmemLayoutLSE{})) * cutlass::sizeof_bits_v / 8); - - // These are tuned for speed. They don't affect correctness. - // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64 - // this helps quite a bit to not have to do causal masking for most of the iterations. - // For hdim 192, separating masking iterations results in register spills. - static constexpr bool SeparateMaskingIterations = kHeadDim <= 64; - // Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share them and then - // shuffle to get the value whenever we need? This can reduce register pressure when SdP_swapAB, where each - // thread needs to keep statistics for (kBlockM / 4) rows. If !SdP_swapAB, each thread only needs to keep - // statistic for 2 rows. - static constexpr bool ShuffleLSE = SdP_swapAB && kHeadDim <= 64; - static constexpr bool ShuffledPsum = SdP_swapAB && kHeadDim <= 64; - static constexpr bool dQacc_use_TMA = kHeadDim < 256; - // For hdim256, we want to slice the dQ MMA (64 x 256 on 2 WGs) into two (64 x 128 on 2 WGs) so that we can - // do atomic add on one half before doing the other half of the MMA, to reduce register pressure. - static constexpr bool Slice_dQKV_Mma = kHeadDim == 256 && !dQacc_use_TMA && dQ_swapAB && AtomLayoutMdQ == 1 && NumMmaWarpGroups == 2; - static_assert(!(Deterministic && Slice_dQKV_Mma), "Deterministic mode not supported with Slice_dQKV_Mma"); - - static constexpr size_t SmemAlignmentP = cutlass::detail::alignment_for_swizzle(SmemLayoutPdS{}); - static constexpr size_t SmemAlignmentdS = cutlass::detail::alignment_for_swizzle(SmemLayoutPdS{}); - // Without this SmemAlignment, with hdim 256 we get "misaligned address" error in TMA - static constexpr size_t SmemAlignmentQKVdO = kHeadDim % 256 == 0 ? 256 : 128; - static constexpr size_t SmemAlignmentV = !Mma_dP_is_RS ? SmemAlignmentQKVdO : cutlass::detail::alignment_for_swizzle(SmemLayoutV{}); - static_assert(SmemAlignmentP >= 128 && SmemAlignmentdS >= 128, "Require at least 128B alignment"); - - // TODO: do we have to worry that smem_dk and smem_dv in the epilogue don't line up w smem_k and smem_v due to alignment? - using SmemdQacc_t = std::conditional_t, cute::array_aligned>>; - using SmemP_t = std::conditional_t, cute::array_aligned, SmemAlignmentP>>; - struct TensorStorage : cute::aligned_struct { - cute::array_aligned, SmemAlignmentQKVdO> smem_k; - cute::array_aligned, SmemAlignmentV> smem_v; - SmemdQacc_t smem_dqacc; - cute::array_aligned, SmemAlignmentQKVdO> smem_q; - cute::array_aligned, SmemAlignmentQKVdO> smem_do; - cute::array_aligned, 128> smem_lse; - cute::array_aligned, 128> smem_dpsum; - SmemP_t smem_p; - cute::array_aligned, SmemAlignmentdS> smem_ds; - }; - - // Host side kernel arguments - struct Arguments { - Element const* const ptr_Q; - ShapeQKV const shape_Q; - StrideQKV const stride_Q; - Element const* const ptr_K; - ShapeQKV const shape_K; - StrideQKV const stride_K; - Element const* const ptr_V; - StrideQKV const stride_V; - Element const* const ptr_dO; - StrideQKV const stride_dO; - ElementAccum* const ptr_dQaccum; - ShapedQaccum const shape_dQaccum; - StridedQaccum const stride_dQaccum; - float const* const ptr_LSE_log2; - ShapeLSE const shape_LSE; - StrideLSE const stride_LSE_log2; - float const* const ptr_dPsum; - StrideLSE const stride_dPsum; - float const softmax_scale; - int const window_size_left, window_size_right; - float const softcap_val; - int const num_batch; - int* const dq_semaphore; - int const* const cu_seqlens_q = nullptr; - int const* const cu_seqlens_k = nullptr; - int const* const seqused_q = nullptr; - int const* const seqused_k = nullptr; - }; - - // Device side kernel params - struct Params { - ShapeQKV const shape_Q; - ShapeQKV const shape_K; - ElementAccum* const ptr_dQaccum; - ShapedQaccum const shape_dQaccum; - StridedQaccum stride_dQaccum; - cutlass::FastDivmod qhead_per_khead_divmod; - TMA_QdO tma_load_Q, tma_load_dO; - TMA_K tma_load_K; - TMA_V tma_load_V; - float const* const ptr_LSE_log2; - ShapeLSE const shape_LSE; - StrideLSE const stride_LSE_log2; - float const* const ptr_dPsum; - StrideLSE const stride_dPsum; - float const softmax_scale, softmax_scale_log2; - int const window_size_left, window_size_right; - float const softcap_val; - int const num_batch; - int* const dq_semaphore; - int const* const cu_seqlens_q = nullptr; - int const* const cu_seqlens_k = nullptr; - int const* const seqused_q = nullptr; - int const* const seqused_k = nullptr; - }; - - static Params - to_underlying_arguments(Arguments const& args) { - Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q); - TMA_QdO tma_load_Q = make_tma_copy_A_sm90( - GmemTiledCopyQdO{}, - mQ, - SmemLayoutQ{}(_, _, _0{}), - TileShape_MNK{}, - ClusterShape{}); // mcast along N mode for this M load, if any - Tensor mdO = make_tensor(make_gmem_ptr(args.ptr_dO), args.shape_Q, args.stride_dO); - TMA_QdO tma_load_dO = make_tma_copy_A_sm90( - GmemTiledCopyQdO{}, - mdO, - SmemLayoutdO{}(_, _, _0{}), - TileShape_MNK{}, - ClusterShape{}); // mcast along N mode for this M load, if any - Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K); - TMA_K tma_load_K = make_tma_copy_B_sm90( - GmemTiledCopyKV{}, - mK, - SmemLayoutK{}, - TileShape_MNK{}, - ClusterShape{}); // no mcast for KV - Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_K, args.stride_V); - TMA_V tma_load_V = make_tma_copy_B_sm90( - GmemTiledCopyKV{}, - mV, - SmemLayoutV{}, - TileShape_MNK{}, - ClusterShape{}); // no mcast for KV - if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); } - // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. - // Right after this, we multiply by log2(e) before applying exp2. - // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val - // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) - // (assigning it to params.softmax_scale_log2). - // In the backward, we need to multiply by - // (1 - tanh^2) * softmax_scale / softcap_val * softcap_val = (1 - tanh^2) * softmax_scale. - // Instead we multiply by (1 - tanh^2) and multiply dK and dV by params.softmax_scale - // (the original softmax_scale) at the end. - return {args.shape_Q, args.shape_K, - args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum, - cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), - tma_load_Q, tma_load_dO, tma_load_K, tma_load_V, - args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, - args.softmax_scale, - !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), - args.window_size_left, args.window_size_right, - !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, - args.num_batch, args.dq_semaphore, - args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k}; - } - - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& params) { - cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor()); - cute::prefetch_tma_descriptor(params.tma_load_dO.get_tma_descriptor()); - cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor()); - cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor()); - } - - template - CUTLASS_DEVICE void - load(Params const& params, - MainloopPipeline pipeline_q, - MainloopPipeline_dO pipeline_do, - PipelineState& smem_pipe_write, - PipelineState_dO& smem_pipe_write_do, - SharedStorage &shared_storage, - SchedulerPrefetch const& scheduler_prefetch, - cute::tuple block_coord - ) { - - auto [n_block, bidh, bidb] = block_coord; - SeqlenInfo_t seqlen_info{ - bidb, get<0>(params.shape_Q), size<0>(params.shape_K), - params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k - }; - auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( - seqlen_info, n_block, bidb, - params.window_size_left, params.window_size_right, 0 /*sink_token_length*/); - // It's possible to have m_block_max <= m_block_min. Loading Q, K can cause illegal memory access. - if constexpr (Is_causal || Is_local || Varlen) { - if (m_block_max <= m_block_min) { - scheduler_prefetch(); - return; - } - } - - Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); - Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{}); - Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); - Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); - Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSE{}); - Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSE{}); - - int bidh_kv = params.qhead_per_khead_divmod.divide(bidh); - - // Prepare the TMA loads - uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); - constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); - uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - bool const is_varlen_q = Varlen && params.cu_seqlens_q; - bool const is_varlen_k = Varlen && params.cu_seqlens_k; - Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor mK = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); - Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); - Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, bidh, !is_varlen_q ? bidb : 0); - Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, bidh, !is_varlen_q ? bidb : 0); - - Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) - Tensor gdO = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) - Tensor gK = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) - Tensor gV = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) - Tensor gLSE = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mLSE), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _) - Tensor gdPsum = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mdPsum), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _) - - Tensor sK_x = make_tensor(sK.data(), make_layout(sK.layout(), Layout<_1>{})); - Tensor gK_x = make_tensor(gK.data(), make_layout(gK.layout(), Layout<_1>{})); - Tensor sV_x = make_tensor(sV.data(), make_layout(sV.layout(), Layout<_1>{})); - Tensor gV_x = make_tensor(gV.data(), make_layout(gV.layout(), Layout<_1>{})); - // auto [tQgQ, tQsQ] = tma_partition(params.tma_load_Q, block_rank_in_cluster, Layout{}, - // group_modes<0, 2>(sQ), group_modes<0, 2>(gQ)); // (TMA, k), (TMA, PIPE) - // auto [tdOgdO, tdOsdO] = tma_partition(params.tma_load_dO, block_rank_in_cluster, Layout{}, - // group_modes<0, 2>(sdO), group_modes<0, 2>(gdO)); // (TMA, k), (TMA, PIPE) - auto block_tma_Q = params.tma_load_Q.get_slice(cluster_local_block_id.y); - auto block_tma_dO = params.tma_load_dO.get_slice(cluster_local_block_id.y); - Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); - Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); - Tensor tdOgdO = group_modes<0, 3>(block_tma_dO.partition_S(gdO)); - Tensor tdOsdO = group_modes<0, 3>(block_tma_dO.partition_D(sdO)); - auto [tKgK, tKsK] = tma_partition(params.tma_load_K, _0{}, Layout<_1>{}, - group_modes<0, 2>(sK_x), group_modes<0, 2>(gK_x)); // (TMA), (TMA) - auto [tVgV, tVsV] = tma_partition(params.tma_load_V, _0{}, Layout<_1>{}, - group_modes<0, 2>(sV_x), group_modes<0, 2>(gV_x)); // (TMA), (TMA) - auto bulk_copy = Copy_Traits{}; - - uint16_t mcast_mask_qdo = 0; - if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int n = 0; n < size<1>(block_layout); ++n) { - mcast_mask_qdo |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, _0{})); - } - } - - int m_block = m_block_min; - - int lane_predicate = cute::elect_one_sync(); - - if (lane_predicate) { - pipeline_q.producer_acquire(smem_pipe_write); - copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), - tQgQ(_, m_block), tQsQ(_, smem_pipe_write.index())); - copy(bulk_copy.with(*pipeline_q.producer_get_barrier(smem_pipe_write)), - gLSE(_, m_block), sLSE(_, smem_pipe_write.index())); - } - - // // Wait for the MMA warpgroups to say that smem_k and smem_v are ready - // cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); - - if (lane_predicate) { - // Copy K tile and V tile from GMEM to SMEM. - shared_storage.pipelines.barrier_KV.arrive_and_expect_tx(TmaTransactionBytesK + TmaTransactionBytesV); - copy(params.tma_load_K.with(reinterpret_cast(shared_storage.pipelines.barrier_KV), 0 /*mcast_mask*/), tKgK, tKsK); - copy(params.tma_load_V.with(reinterpret_cast(shared_storage.pipelines.barrier_KV), 0 /*mcast_mask*/), tVgV, tVsV); - - #pragma unroll (kHeadDim < 256 ? 2 : 1) - for (; m_block < m_block_max - 1; ++m_block) { - // If Q and dO have the same number of stages, we can use the same pipeline state variable - // to reduce registers - PipelineState_dO smem_pipe_write_do_cur = cute::conditional_return(smem_pipe_write, smem_pipe_write_do); - pipeline_do.producer_acquire(smem_pipe_write_do_cur); - copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), - tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write_do_cur.index())); - copy(bulk_copy.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur)), - gdPsum(_, m_block), sdPsum(_, smem_pipe_write_do_cur.index())); - if constexpr (!Q_dO_same_stages) { ++smem_pipe_write_do; } - ++smem_pipe_write; - pipeline_q.producer_acquire(smem_pipe_write); - copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), - tQgQ(_, m_block + 1), tQsQ(_, smem_pipe_write.index())); - copy(bulk_copy.with(*pipeline_q.producer_get_barrier(smem_pipe_write)), - gLSE(_, m_block + 1), sLSE(_, smem_pipe_write.index())); - } - } - scheduler_prefetch(); - if (lane_predicate) { - PipelineState_dO smem_pipe_write_do_cur = cute::conditional_return(smem_pipe_write, smem_pipe_write_do); - pipeline_do.producer_acquire(smem_pipe_write_do_cur); - copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), - tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write_do_cur.index())); - copy(bulk_copy.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur)), - gdPsum(_, m_block), sdPsum(_, smem_pipe_write_do_cur.index())); - if constexpr (!Q_dO_same_stages) { ++smem_pipe_write_do; } - ++smem_pipe_write; - } - if constexpr (Q_dO_same_stages) { smem_pipe_write_do = smem_pipe_write; } - } - - /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster - CUTLASS_DEVICE void - load_tail(MainloopPipeline pipeline_q, MainloopPipeline_dO pipeline_do, - PipelineState& smem_pipe_write) { - static_assert(Q_dO_same_stages, "Q and dO must have the same number of stages"); - // Need to copy since pipeline_q.producer_tail(smem_pipe_write) will increment smem_pipe_write - PipelineState smem_pipe_write_do = smem_pipe_write; - // Issue the epilogue waits - if (cute::elect_one_sync()) { - /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was still inverted from make_producer_start_state - */ - pipeline_q.producer_tail(smem_pipe_write); - pipeline_do.producer_tail(smem_pipe_write_do); - } - } - - /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster - CUTLASS_DEVICE void - load_tail(MainloopPipeline pipeline_q, MainloopPipeline_dO pipeline_do, - PipelineState& smem_pipe_write, PipelineState_dO& smem_pipe_write_do) { - // Issue the epilogue waits - if (cute::elect_one_sync()) { - /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was still inverted from make_producer_start_state - */ - pipeline_q.producer_tail(smem_pipe_write); - pipeline_do.producer_tail(smem_pipe_write_do); - } - } - - template - CUTLASS_DEVICE void - store_dq(Params const& params, - SharedStorage &shared_storage, - cute::tuple block_coord - ) { - if constexpr (!dQacc_use_TMA) { return; } - - auto [n_block, bidh, bidb] = block_coord; - SeqlenInfo_t seqlen_info{ - bidb, get<0>(params.shape_Q), size<0>(params.shape_K), - params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k - }; - auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( - seqlen_info, n_block, bidb, params.window_size_left, - params.window_size_right, 0 /*sink_token_length*/); - // It's possible to have m_block_max <= m_block_min. Exit early - if constexpr (Is_causal || Is_local || Varlen) { - if (m_block_max <= m_block_min) { return; } - } - - Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dqacc.data()), SmemLayoutdQaccum{}); - static constexpr int dQ_TMA_num_bytes = CUTE_STATIC_V(size<0>(sdQ)) * sizeof(ElementAccum); - - bool const is_varlen = Varlen && params.cu_seqlens_q; - Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), - params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0); - Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(_)); // (M * K, _) - Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int{}); // (M * K / WG, WG, _) - - int const num_batch = params.num_batch; - int const num_head = get<2>(params.shape_Q); - int *lock_ptr = !Deterministic ? nullptr : params.dq_semaphore + bidb * num_head + bidh; - using Barrier = cutlass::GenericBarrier; - bool const lane_predicate = cute::elect_one_sync(); - int m_block = m_block_min; - #pragma unroll 2 - for (; m_block < m_block_max; ++m_block) { - if constexpr (Deterministic) { - Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block); - } - #pragma unroll - for (int warpgroup_idx = 0; warpgroup_idx < NumMmaWarpGroups; ++warpgroup_idx) { - cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQFullWG1) + warpgroup_idx /*id*/); // sdQ full, to be written to gmem - if (lane_predicate) { - SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdQ(_, warpgroup_idx).data()), raw_pointer_cast(gdQaccum(_, warpgroup_idx, m_block).data()), dQ_TMA_num_bytes, static_cast(TMA::CacheHintSm90::EVICT_LAST)); - tma_store_arrive(); - } - } - // Note, the for_each() function is required here to ensure `warpgroup_idx` is of type Int. - for_each(make_int_sequence{}, [&] (auto warpgroup_idx) { - if (lane_predicate) { tma_store_wait(); } - cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmptyWG1) + warpgroup_idx /*id*/); // sdQ empty, ready to be written to - }); - if constexpr (Deterministic) { - Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head); - } - } - if constexpr (Is_local && Deterministic) { - constexpr int kBlockM = get<0>(TileShape_MNK{}); - int const m_block_global_max = cute::ceil_div(seqlen_info.seqlen_q, kBlockM); - #pragma unroll 2 - for (; m_block < m_block_global_max; ++m_block) { - Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head); - } - } - } - - CUTLASS_DEVICE void - mma_init() { - // We're not currently using this bc we're not using persistent scheduler - // // Tell producer (warp 0) that smem_k and smem_v are ready - // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); - int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); - if constexpr (dQacc_use_TMA) { - if (warp_idx_in_warpgroup == 0) { - cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmptyWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/); // sdQ empty, ready to be written to - } - } - } - - template - CUTLASS_DEVICE bool - mma(Params const& params, - MainloopPipeline pipeline_q, - MainloopPipeline_dO pipeline_do, - PipelineState& smem_pipe_read, - PipelineState_dO& smem_pipe_read_do, - FrgTensordKV& tdKrdK, - FrgTensordKV& tdVrdV, - int thread_idx, - int &work_idx, - cute::tuple block_coord, - SharedStorage& shared_storage - ) { - static_assert(is_rmem::value, "dK and dV tensor must be rmem resident."); - - int n_block = get<0>(block_coord); - int bidb = get<2>(block_coord); - SeqlenInfo_t seqlen_info{ - bidb, get<0>(params.shape_Q), size<0>(params.shape_K), - params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k - }; - auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( - seqlen_info, n_block, bidb, params.window_size_left, - params.window_size_right, 0 /*sink_token_length*/); - // It's possible to have m_block_max <= m_block_min. Exit early - if constexpr (Is_causal || Is_local || Varlen) { - if (m_block_max <= m_block_min) { return false; } - } - - Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); - Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{}); - Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); - Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); - Tensor sQt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQt{}); - Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdOt{}); - Tensor sKt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutKt{}); - Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdS{}); - Tensor sP_pi = cute::as_position_independent_swizzle_tensor(sP); - Tensor sPt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdSt{}); - Tensor sPt_pi = cute::as_position_independent_swizzle_tensor(sPt); - Tensor sdS = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdS{}); - Tensor sdS_pi = cute::as_position_independent_swizzle_tensor(sdS); - Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdSt{}); - Tensor sdSt_pi = cute::as_position_independent_swizzle_tensor(sdSt); - Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dqacc.data()), SmemLayoutdQaccum{}); - Tensor sLSEMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSEMma{}); - Tensor sdPsumMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSEMma{}); - - static_assert(stride<0>(typename TiledMmaSdP::ALayout{}) == 0 and - stride<0>(typename TiledMmaSdP::BLayout{}) == 0 and - size<0>(typename TiledMmaSdP::ALayout{}) == cutlass::NumThreadsPerWarpGroup and - size<0>(typename TiledMmaSdP::BLayout{}) == cutlass::NumThreadsPerWarpGroup, - "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); - constexpr int MmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup; - Layout warp_group_thread_layout = make_layout(make_shape(Int{}), - make_stride(Int{})); - Layout warp_group_thread_layout_dq = make_layout(make_shape(Int{}), - make_stride(Int{})); - - int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); - TiledMmaSdP tiled_mma_SdP; - using TiledMmadP = std::conditional_t; - TiledMmadP tiled_mma_dP; - TiledMmadKV tiled_mma_dKV; - TiledMmadQ tiled_mma_dQ; - - auto wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx)); - auto wg_mma_dP = tiled_mma_dP.get_slice(warp_group_thread_layout(warp_group_idx)); - auto thread_mma_SdP = tiled_mma_SdP.get_thread_slice(thread_idx); - auto wg_mma_dKV = tiled_mma_dKV.get_slice(warp_group_thread_layout(warp_group_idx)); - auto wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout_dq(warp_group_idx)); - - auto smem_tiled_copy_PdS = make_tiled_copy_C(SmemCopyAtomPdS{}, tiled_mma_SdP); - auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(thread_idx); - - R2STiledCopydQaccum r2s_tiled_copy_dQaccum; - auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(thread_idx); - Tensor tdQsdQaccum = r2s_thr_copy_dQaccum.partition_D(sdQ); - // if (thread_idx == 0) { print(sdQ); printf("\n"); print(tdQsdQaccum); printf("\n"); } - - // Allocate "fragments/descriptors" - // We have to use the templated mma_partition_fragment_AB instead of cute::conditional_return or lambda, - // because some partition_fragment_A/B don't compile. - // https://stackoverflow.com/questions/50051473/if-constexpr-in-c17-does-not-work-in-a-non-templated-function - Tensor tSrQ = mma_partition_fragment_AB(wg_mma_SdP, sQ); - Tensor tSrK = mma_partition_fragment_AB(wg_mma_SdP, sK); - Tensor tdPrdO = mma_partition_fragment_AB(wg_mma_SdP, sdO); - Tensor tdPrV = mma_partition_fragment_AB(wg_mma_dP, sV); - Tensor tdVrdO = mma_partition_fragment_AB(wg_mma_dKV, sdOt); - Tensor tdKrQ = mma_partition_fragment_AB(wg_mma_dKV, sQt); - Tensor tdQrdS = mma_partition_fragment_AB(wg_mma_dQ, sdS); - Tensor tdQrK = mma_partition_fragment_AB(wg_mma_dQ, sKt); - - Tensor tPsP = smem_thr_copy_PdS.partition_D(cute::conditional_return(sP_pi, sPt_pi)); // ((Atom,AtomNum),PIPE_M,PIPE_N) - Tensor tdSsdS = smem_thr_copy_PdS.partition_D(cute::conditional_return(sdS_pi, sdSt_pi)); // ((Atom,AtomNum),PIPE_M,PIPE_N) - // if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_PdS); print(sP_pi); printf("\n"); print(sPt_pi); printf("\n"); print(tPsP); printf("\n"); print(tdSsdS); printf("\n"); } - - // thread_mma_SdP.partition_C(sLSEMma) has shape ((2, 2, V), MMA_M, MMA_N, PIPE), we only take the col indices - // or row indices, depending on whether SdP_swapAB. - Tensor tLSEsLSE = cute::conditional_return( - group_modes<0, 2>(thread_mma_SdP.partition_C(sLSEMma)(make_coord(_0{}, _, _0{}), _, _0{}, _)), // (2, MMA_M, PIPE) - group_modes<0, 3>(thread_mma_SdP.partition_C(sLSEMma)(make_coord(_, _0{}, _), _0{}, _, _))); // (2, V, MMA_N, PIPE) - Tensor tLSEsdPsum = cute::conditional_return( - group_modes<0, 2>(thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_0{}, _, _0{}), _, _0{}, _)), - group_modes<0, 3>(thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_, _0{}, _), _0{}, _, _))); - // if (blockIdx.x == 0 && threadIdx.x == 128) { print(sLSEMma); printf("\n"); print(tLSEsLSE); printf("\n"); } - // If we want to split the stats among the 8 threads that share the same rows. - static constexpr int kStatsPerThread = cute::ceil_div(decltype(size(tLSEsLSE))::value, 8); - - auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { - auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - }; - - int bidh = get<1>(block_coord); - int const seqlen_q = seqlen_info.seqlen_q; - int const seqlen_k = seqlen_info.seqlen_k; - - // For the case where we do atomicAdd directly to gdQaccum instead of using TMA - bool const is_varlen = Varlen && params.cu_seqlens_q; - Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), - params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0); - Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(_)); // (M * K, _) - Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int{}); // (M * K / WG, WG, _) - // We can reuse r2s_thr_copy_dQaccum for this partitioning - Tensor tdQgdQaccum = r2s_thr_copy_dQaccum.partition_D(gdQaccum); - // if (blockIdx.x == 0 && threadIdx.x == 128) { print(mdQaccum); printf("\n"); print(gdQaccum_); printf("\n"); print(gdQaccum); printf("\n"); print(tdQgdQaccum); printf("\n"); } - - flash::Mask mask( - thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, - params.qhead_per_khead_divmod - ); - - int m_block = m_block_min; - - clear(tdKrdK); - clear(tdVrdV); - // tiled_mma_dKV.accumulate_ = GMMA::ScaleOut::Zero; - - cutlass::ConsumerToken barrier_token = static_cast(shared_storage.pipelines.barrier_KV.try_wait(work_idx % 2)); - if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.pipelines.barrier_KV.wait(work_idx % 2); } - - if constexpr (Mma_dP_is_RS) { - using SmemCopyAtomV = Copy_Atom; - auto smem_tiled_copy_V = make_tiled_copy_A(SmemCopyAtomV{}, tiled_mma_dP); - auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(thread_idx); - Tensor tdPrV_copy_view = smem_thr_copy_V.retile_D(tdPrV); - Tensor tdPsV_copy_view = smem_thr_copy_V.partition_S(cute::as_position_independent_swizzle_tensor(sV)); - cute::copy(smem_tiled_copy_V, tdPsV_copy_view, tdPrV_copy_view); - } - - auto bwd_step = [&](int m_block, auto mask_fn) { - Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select(TileShape_MNK{})); - consumer_wait(pipeline_q, smem_pipe_read); - flash::gemm(tiled_mma_SdP, tSrQ(_, _, _, smem_pipe_read.index()), tSrK, tSrS); - Tensor tLSErLSE = cute::conditional_return(make_fragment_like(tLSEsLSE(_, _0{})), make_tensor(Int{})); - if constexpr (!ShuffleLSE) { - cute::copy(tLSEsLSE(_, smem_pipe_read.index()), tLSErLSE); - } else { - #pragma unroll - for (int i = 0; i < kStatsPerThread; ++i) { - // It's ok to read OOB, since we made sure sLSE is large enough and we won't use the OOB values - tLSErLSE(i) = tLSEsLSE((thread_idx % 32) / 4 + i * 8, smem_pipe_read.index()); - } - } - Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select(TileShape_MNK{})); - PipelineState_dO smem_pipe_read_do_cur = cute::conditional_return(smem_pipe_read, smem_pipe_read_do); - consumer_wait(pipeline_do, smem_pipe_read_do_cur); - flash::gemm(tiled_mma_dP, tdPrdO(_, _, _, smem_pipe_read_do_cur.index()), tdPrV, tdPrdP); - warpgroup_wait<1>(); - if constexpr (Has_softcap) { flash::apply_softcap(tSrS, params.softcap_val); } - - // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) - Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout())); - // dtanh needs to happen before masking, otherwise we get 1 - (-inf)^2 = NaN in the dtanh - auto dtanh = [&] { if constexpr (Has_softcap) return flash::calculate_dtanh(scores); else return nullptr; }(); - mask_fn(tSrS, m_block); - #pragma unroll - for (int mi = 0; mi < size<0>(scores); ++mi) { - float const lse_scaled = [&] { - if constexpr (!ShuffleLSE) return tLSErLSE(mi); - else return __shfl_sync(0xffffffff, tLSErLSE(mi / 8), (mi % 8) * 4 + (thread_idx % 4)); - }(); - #pragma unroll - for (int ni = 0; ni < size<1>(scores); ++ni) { - scores(mi, ni) = exp2f(scores(mi, ni) * params.softmax_scale_log2 - lse_scaled); - } - } - - Tensor tLSErdPsum = cute::conditional_return(make_fragment_like(tLSEsdPsum(_, _0{})), make_tensor(Int{})); - if constexpr (!ShuffledPsum) { - cute::copy(tLSEsdPsum(_, smem_pipe_read_do_cur.index()), tLSErdPsum); - } else { - #pragma unroll - for (int i = 0; i < kStatsPerThread; ++i) { - tLSErdPsum(i) = tLSEsdPsum((thread_idx % 32) / 4 + i * 8, smem_pipe_read_do_cur.index()); - } - } - - warpgroup_wait<0>(); - // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) - Tensor dS = make_tensor(tdPrdP.data(), scores.layout()); - #pragma unroll - for (int mi = 0; mi < size<0>(dS); ++mi) { - float const dP_sum_cur = [&] { - if constexpr (!ShuffledPsum) return tLSErdPsum(mi); - else return __shfl_sync(0xffffffff, tLSErdPsum(mi / 8), (mi % 8) * 4 + (thread_idx % 4)); - }(); - #pragma unroll - for (int ni = 0; ni < size<1>(dS); ++ni) { - dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum_cur); - if constexpr (Has_softcap) { dS(mi, ni) *= dtanh(mi, ni); } - } - } - - // Convert scores from fp32 to fp16/bf16 - Tensor rP = make_tensor_like(tSrS); - flash::convert_type_out(tSrS, rP); - if constexpr (!Mma_dKV_is_RS) { - // Need to sync to make sure P has already been used in the previous iteration before writing new values - if constexpr (kStages_dS == 1) { - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(BwdNamedBarriers::PdS) /*id*/); - } - Tensor tPaP = smem_thr_copy_PdS.retile_S(rP); // ((Atom,AtomNum), MMA_N, MMA_N) - cute::copy(smem_tiled_copy_PdS, tPaP, tPsP(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index()))); - } - Tensor rdS = make_tensor_like(tdPrdP); - flash::convert_type_out(tdPrdP, rdS); - // If there's double buffering on dS, we don't need to sync here. - // Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ. - // But because both WGs have to sync at the end of the loop and double buffering, - // this race condition is not possible. - // This sync is to ensure (1) P is written in case of !Mma_dKV_is_RS and - // (2) dS is already read by the Mma in the previous iteration in case of Mma_dKV_is_RS. - if constexpr (!Mma_dKV_is_RS || (kStages_dS == 1 && Mma_dKV_is_RS)) { - cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(BwdNamedBarriers::PdS) /*id*/); - } - // For hdim 64, It's faster to write to smem_dS first before the dV gemm - Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N) - cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index()))); - - if constexpr (!Slice_dQKV_Mma) { - // Most cases take this path, except for hdim256 where we want to slice to reduce register pressure - if constexpr (Mma_dKV_is_RS) { - Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs(tSrS.layout())); - flash::gemm(tiled_mma_dKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); - } else { - Tensor tdVrP = mma_partition_fragment_AB(wg_mma_dKV, sPt); - Tensor tdVrP_cur = tdVrP(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); - flash::gemm(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); - } - // SMEM fence to make sure sdS is written before it's read by WGMMA - cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(BwdNamedBarriers::PdS) /*id*/); - Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{})); - Tensor tdQrdS_cur = tdQrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); - flash::gemm(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ); - pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dQ - - if constexpr (Mma_dKV_is_RS) { - Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); - flash::gemm(tiled_mma_dKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); - } else { - Tensor tdKrdS = mma_partition_fragment_AB(wg_mma_dKV, sdSt); - Tensor tdKrdS_cur = tdKrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); - flash::gemm(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); - } - if constexpr (dQacc_use_TMA) { - int const warp_group_idx = flash::canonical_warp_group_idx_nosync() - 1; - cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmptyWG1) + warp_group_idx /*id*/); // sdQ full, to be written to gmem - Tensor taccdQrdQ = r2s_thr_copy_dQaccum.retile_S(tdQrdQ); - cute::copy(r2s_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum); - cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQFullWG1) + warp_group_idx /*id*/); // sdQ full, to be written to gmem - } else { - // We can reuse r2s_thr_copy_dQaccum for this partitioning - Tensor tdQrdQ_atomic = recast(r2s_thr_copy_dQaccum.retile_S(tdQrdQ)); - Tensor tdQgdQaccum_atomic = recast(tdQgdQaccum(_, _, _, m_block)); - static_assert(CUTE_STATIC_V(size(tdQrdQ_atomic)) == CUTE_STATIC_V(size(tdQgdQaccum_atomic))); - #pragma unroll - for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } - } - - } else { // Slice_dQKV_Mma - - static_assert(!(Slice_dQKV_Mma && Mma_dKV_is_RS)); - Tensor tdVrP = mma_partition_fragment_AB(wg_mma_dKV, sPt); - Tensor tdVrP_cur = tdVrP(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); - flash::gemm(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); - - cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(BwdNamedBarriers::PdS) /*id*/); - Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{})); - Tensor tdQrdS_cur = tdQrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); - flash::gemm(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ); - flash::gemm(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); - Tensor tdQrdQ_atomic = recast(r2s_thr_copy_dQaccum.retile_S(tdQrdQ)); - Tensor tdQgdQaccum_atomic = recast(tdQgdQaccum(_, _, _, m_block)); - #pragma unroll - for (int i = 0; i < size(tdQrdQ_atomic) / 2; ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } - - Tensor tdKrdS = mma_partition_fragment_AB(wg_mma_dKV, sdSt); - Tensor tdKrdS_cur = tdKrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); - flash::gemm(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); - pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dO - - flash::gemm(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ); - #pragma unroll - for (int i = size(tdQrdQ_atomic) / 2; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } - - flash::gemm(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); - } - - warpgroup_wait<0>(); - pipeline_q.consumer_release(smem_pipe_read); // release Q - ++smem_pipe_read; - if constexpr (!Q_dO_same_stages) { ++smem_pipe_read_do; } - }; - - // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64 - // this helps quite a bit to not have to do causal masking for most of the iterations. - if constexpr ((Is_causal || Is_local) && SeparateMaskingIterations) { - auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - int const m_block_masking_max = ((n_block + 1) * kBlockN - 1 + seqlen_q - seqlen_k - params.window_size_right) / kBlockM + 1; - CUTLASS_PRAGMA_NO_UNROLL - for (; m_block < std::min(m_block_max, m_block_masking_max); ++m_block) { - bwd_step(m_block, mask_fn); - } - } - - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - int const m_block_max_before_local_mask = !Is_local || !SeparateMaskingIterations - ? m_block_max - : std::min(m_block_max, (n_block * kBlockN + seqlen_q - seqlen_k + params.window_size_left) / kBlockM); - - auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; - CUTLASS_PRAGMA_NO_UNROLL - for (; m_block < m_block_max_before_local_mask; ++m_block) { - bwd_step(m_block, mask_fn); - } - - if constexpr (Is_local && SeparateMaskingIterations) { - auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; - CUTLASS_PRAGMA_NO_UNROLL - for (; m_block < m_block_max; ++m_block) { - bwd_step(m_block, mask_fn); - } - } - - // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tdVrdV); } - #pragma unroll - for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.softmax_scale; } - - if constexpr (Q_dO_same_stages) { smem_pipe_read_do = smem_pipe_read; } - ++work_idx; - return true; - } - -}; - -} // namespace flash diff --git a/flash-attn/mainloop_fwd_sm80.hpp b/flash-attn/mainloop_fwd_sm80.hpp deleted file mode 100644 index 4ce024f346a219d3b61156576293154616f8cef7..0000000000000000000000000000000000000000 --- a/flash-attn/mainloop_fwd_sm80.hpp +++ /dev/null @@ -1,855 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include -#include -#include -#include - -#include "cute/tensor.hpp" - -#include "seqlen.h" -#include "block.h" -#include "mask.h" -#include "pack_gqa.h" -#include "paged_kv.h" -#include "rotary.h" -#include "utils.h" - -namespace flash { - -using namespace cute; - -template -struct CollectiveMainloopFwdSm80 { - - static constexpr int kStages = Stages; - static_assert(kStages > 0, "kStages must be greater than 0"); - using TileShape_MNK = TileShape_MNK_; - using TileShape_MNK_PV = Shape(TileShape_MNK{})), Int, decltype(get<1>(TileShape_MNK{}))>; - using Element = Element_; - using ElementAccum = ElementAccum_; - using ElementSAux = ElementSAux_; - using ArchTag = ArchTag_; - static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v;; - static constexpr bool Is_causal = Is_causal_; - static constexpr bool Is_local = Is_local_; - static constexpr bool Has_softcap = Has_softcap_; - static constexpr bool Varlen = Varlen_; - static constexpr bool PagedKV = PagedKV_; - static constexpr bool AppendKV = AppendKV_; - static constexpr bool PackGQA = PackGQA_; - static constexpr bool Split = Split_; - static constexpr bool Transpose_V = Is_FP8; - - static_assert(ArchTag::kMinComputeCapability >= 80); - - static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80; - - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - static constexpr int kHeadDim = get<2>(TileShape_MNK{}); - - using SeqlenInfo_t = flash::SeqlenInfoQKNewK; - using BlockMN_t = flash::BlockMN; - - using MMA_Atom_Arch = std::conditional_t< - ArchTag::kMinComputeCapability >= 80, - std::conditional_t< - std::is_same_v, - MMA_Atom, - MMA_Atom - >, - MMA_Atom - >; - using TiledMma = TiledMMA< - MMA_Atom_Arch, - Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group - Tile, _16, _16>>; - - static constexpr int NumMmaThreads = size(TiledMma{}); - static constexpr int NumProducerThreads = NumMmaThreads; // For compatibility with TileScheduler - - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); - // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each - // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. - static constexpr int kBytePerRow = kHeadDim * sizeof(Element); - static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); - - static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1)); - static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4); - using SmemLayoutAtomQKV = decltype( - composition(Swizzle{}, - Layout>, - Stride, _1>>{})); - using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQKV{}, select<0, 2>(TileShape_MNK{}))); - - using SmemLayoutK = decltype(tile_to_shape( - SmemLayoutAtomQKV{}, - make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); - - using SmemLayoutV = decltype(tile_to_shape( - SmemLayoutAtomQKV{}, - make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); - using SmemLayoutVt = decltype( - composition(SmemLayoutV{}, - make_ordered_layout(make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}), - Step<_2, _1, _3>{}))); - - using SmemCopyAtom = Copy_Atom; - using SmemCopyAtomTransposed = Copy_Atom; - - // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading - // from the same address by the same threadblock. This is slightly faster. - using GmemCopyAtom = Copy_Atom, - AutoVectorizingCopyWithAssumedAlignment<128> - >, Element>; - - static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; - static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); - using GmemLayoutAtom = Layout, Int>, - Stride, _1>>; - using GmemTiledCopyQKV = decltype( - make_tiled_copy(GmemCopyAtom{}, - GmemLayoutAtom{}, - Layout>>{})); // Val layout, 8 or 16 vals per read - // So that we don't have to check if we overshot kBlockM when we load Q - static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0); - - // For AppendKV, We want each thread to have at least 2 loads in the K direction since in the case of - // non-interleaved rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), - // each thread will load twice from the same row. - static constexpr int kBytePerHalfRow = kHeadDim / 2 * sizeof(Element); - static constexpr int kBlockKGmemAppend = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element); - static constexpr int kGmemThreadsPerRowAppend = kBlockKGmemAppend / kGmemElemsPerLoad; - static_assert(NumMmaThreads % kGmemThreadsPerRowAppend == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRowAppend"); - // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where - // these threads share the same page table entry and share the work of computing pointers to paged K and paged V. - static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRowAppend == 0, "kGmemThreadsPerRowAppend must divide NumThreadsPerWarp"); - using GmemLayoutAtomAppend = Layout, Int>, - Stride, _1>>; - // If AppendKV, we'll be loading Q for rotary, and we assume divisibility to avoid predication - static_assert(!AppendKV || kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtomAppend{})) == 0, "kBlockM must be a multiple of NumMmaThreads / kGmemThreadsPerRowAppend"); - using GmemTiledCopyAppendKV = decltype( - make_tiled_copy(Copy_Atom, Element>{}, - GmemLayoutAtomAppend{}, - Layout>>{})); // Val layout, 8 or 16 vals per store - - using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) - using StrideQK = cute::Stride; - using StrideV = StrideQK; - // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits) - using ShapeQPacked = std::conditional_t, int32_t, int32_t, int32_t>>; - using StrideQPacked = std::conditional_t, _1, int64_t, int64_t>>; - using ShapePageTable = cute::Shape; // (batch, max_num_pages_per_seq) - using StridePageTable = cute::Stride; - using ShapeRotary = cute::Shape; // (seqlen_ro, rotary_dim // 2) - using StrideRotary = cute::Stride; - using StrideDescale = cute::Stride; - - static constexpr bool Share_QV_Smem = Q_in_regs; - - struct TensorStorageSharedQV : cute::aligned_struct<128> { - union { - cute::array_aligned> smem_v; - cute::array_aligned> smem_q; - }; - cute::array_aligned> smem_k; - }; - - struct TensorStorageSeparateQV : cute::aligned_struct<128> { - cute::array_aligned> smem_v; - cute::array_aligned> smem_k; - cute::array_aligned> smem_q; - }; - - using TensorStorage = std::conditional_t; - - // Host side kernel arguments - struct Arguments { - Element const* const ptr_Q; - ShapeQKV const shape_Q; - StrideQK const stride_Q; - Element* const ptr_K; // Not Element const* since we might append to KV cache in-place - ShapeQKV const shape_K; - StrideQK const stride_K; - Element* const ptr_V; - int32_t const headdim_v; - StrideV const stride_V; - Element const* const ptr_K_new; - ShapeQKV const shape_K_new; - StrideQK const stride_K_new; - Element const* const ptr_V_new; - StrideV const stride_V_new; - Element const* const ptr_Qv; - StrideQK const stride_Qv; - Element const* const ptr_rotary_cos; - ShapeRotary const shape_rotary; - StrideRotary const stride_rotary_cos; - Element const* const ptr_rotary_sin; - StrideRotary const stride_rotary_sin; - bool const is_rotary_interleaved; - int const* const ptr_pagetable; - ShapePageTable const shape_pagetable; - StridePageTable const stride_pagetable; - float const softmax_scale; - float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; - StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; - int const window_size_left = -1, window_size_right = -1; - float const softcap_val; - int const num_splits; - int const* const kv_batch_idx = nullptr; - int const* const cu_seqlens_q = nullptr; - int const* const cu_seqlens_k = nullptr; - int const* const cu_seqlens_k_new = nullptr; - int const* const seqused_q = nullptr; - int const* const seqused_k = nullptr; - int const* const leftpad_k = nullptr; - int const* const seqlens_rotary = nullptr; - ElementSAux const* const ptr_S_aux = nullptr; - }; - - // Device side kernel params - struct Params { - Element const* const ptr_Q; - ShapeQKV const shape_Q; - StrideQK const stride_Q; - ShapeQPacked const shape_Q_packed; - StrideQPacked const stride_Q_packed; - Element* const ptr_K; - ShapeQKV const shape_K; - StrideQK const stride_K; - Element* const ptr_V; - int32_t const headdim_v; - StrideV const stride_V; - Element const* const ptr_K_new; - ShapeQKV const shape_K_new; - StrideQK const stride_K_new; - Element const* const ptr_V_new; - StrideV const stride_V_new; - Element const* const ptr_rotary_cos; - ShapeRotary const shape_rotary; - StrideRotary const stride_rotary_cos; - Element const* const ptr_rotary_sin; - StrideRotary const stride_rotary_sin; - bool const is_rotary_interleaved; - int const* const ptr_pagetable; - ShapePageTable const shape_pagetable; - StridePageTable const stride_pagetable; - cutlass::FastDivmod page_size_divmod; - cutlass::FastDivmod qhead_per_khead_divmod; - float const softmax_scale_log2; - float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; - StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; - float const softcap_val; - int const window_size_left, window_size_right; - int const num_splits; - int const* const kv_batch_idx = nullptr; - int const* const cu_seqlens_q = nullptr; - int const* const cu_seqlens_k = nullptr; - int const* const cu_seqlens_k_new = nullptr; - int const* const seqused_q = nullptr; - int const* const seqused_k = nullptr; - int const* const leftpad_k = nullptr; - int const* const seqlens_rotary = nullptr; - ElementSAux const* const ptr_S_aux = nullptr; - }; - - static Params - to_underlying_arguments(Arguments const& args) { - // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size) - int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K)); - auto const shape_Q_packed = cute::conditional_return( - args.shape_Q, - make_shape(make_shape(qhead_per_khead, get<0>(args.shape_Q)), get<1>(args.shape_Q), get<2>(args.shape_K), get<3>(args.shape_Q)) - ); - auto const stride_Q_packed = cute::conditional_return( - args.stride_Q, - make_stride(make_stride(get<2>(args.stride_Q), get<0>(args.stride_Q)), get<1>(args.stride_Q), get<2>(args.stride_Q) * qhead_per_khead, get<3>(args.stride_Q)) - ); - if (get<1>(args.shape_rotary) > 0) { - assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr); - } - assert(args.num_splits >= 1); - // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. - // Right after this, we multiply by log2(e) before applying exp2. - // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val - // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) - // (assigning it to params.softmax_scale_log2). - return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed, - args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V, - args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new, - args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos, - args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved, - args.ptr_pagetable, args.shape_pagetable, args.stride_pagetable, - cutlass::FastDivmod(int(get<0>(args.shape_K))), - cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), - !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), - args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale, - args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, - !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, - args.window_size_left, args.window_size_right, - !Split ? 1 : args.num_splits, - args.kv_batch_idx, - args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, - args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary, - args.ptr_S_aux}; - } - - template - CUTLASS_DEVICE bool - mma(Params const& params, - FrgTensorO& tOrO, - Softmax& softmax, - int const thread_idx, - SeqlenInfo_t const& seqlen_info, - cute::tuple block_coord, - SharedStorage& shared_storage - ) { - static_assert(is_rmem::value, "O tensor must be rmem resident."); - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - - // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda - int const m_block = get<0>(block_coord); - int const bidh = get<1>(block_coord); - int const bidb = get<2>(block_coord); - int const split_idx = get<3>(block_coord); - int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; - auto n_block_min_max = BlockMN_t::get_n_block_min_max( - seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); - int const n_block_min = get<0>(n_block_min_max); - int const n_block_max = get<1>(n_block_min_max); - // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier - if constexpr (Is_causal || Is_local || Varlen || Split) { - if (n_block_max <= n_block_min) { return false; } - } - - Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); - Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); - Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); - Tensor sVt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{}); - - bool const is_varlen_q = Varlen && params.cu_seqlens_q; - bool const is_varlen_k = Varlen && params.cu_seqlens_k; - - int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb]; - Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor gQ = local_tile(mQ, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) - Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K + seqlen_info.offset_k * get<0>(params.stride_K)), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); - Tensor gK = local_tile(mK, select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V + seqlen_info.offset_k * get<0>(params.stride_V)), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); - Tensor gV = local_tile(mV, select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - - GmemTiledCopyQKV gmem_tiled_copy_QKV; - auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(thread_idx); - auto gmem_thr0_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(_0{}); // For index calculation - - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) - Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) - Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - - TiledMma tiled_mma; - auto thr_mma = tiled_mma.get_slice(thread_idx); - - // Allocate "fragments/descriptors" - Tensor tSrQ = thr_mma.partition_fragment_A(sQ); - - // Copy Atom retiling - auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtom{}, tiled_mma); - auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(thread_idx); - auto smem_tiled_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma); - auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(thread_idx); - auto smem_tiled_copy_V = make_tiled_copy_B(SmemCopyAtomTransposed{}, tiled_mma); - auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(thread_idx); - Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); - Tensor tSsK = smem_thr_copy_K.partition_S(sK); - Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - - // Predicates - Tensor cKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); - Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); - Tensor t0KVcKV = gmem_thr0_copy_QKV.partition_S(cKV); - Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); - #pragma unroll - for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_K); } - - int const seqlen_q = seqlen_info.seqlen_q; - int const seqlen_k = seqlen_info.seqlen_k; - int n_block = n_block_max - 1; - - // Prologue: load Q, K, V - // If persistent, we don't need to wait for the previous work_idx to finish - // since we assume that all MMA threads sync in the epilogue before writing to smem_o. - // So any thread gets there, all threads must have finished the previous MMA and at least started - // writing to smem_o. - // If persistent, need to sync to make sure all threads have finished with smem_o before writing to smem_v - if constexpr (Share_QV_Smem) { __syncthreads(); } - if constexpr (!PackGQA) { - Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); - Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor cQ = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); - Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); - Tensor t0QcQ = gmem_thr0_copy_QKV.partition_S(cQ); - Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); - #pragma unroll - for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_Q); } - // Instead of passing in tQcQ, we pass in t0QcQ and subtract the offset from the limit - // (seqlen_q - m_block * kBlockM). This is because the entries of t0QcQ are known at compile time. - // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy( - gmem_tiled_copy_QKV, tQgQ, tQsQ, t0QcQ, tQpQ, seqlen_info.seqlen_q - m_block * kBlockM - get<0>(tQcQ(_0{}, _0{}, _0{})) - ); - } else { - using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element>; - PackGQAt::load_Q(mQ, sQ, params.qhead_per_khead_divmod, thread_idx, seqlen_q, m_block); - } - cute::cp_async_fence(); - - using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/>; - PagedKVManager_t paged_kv_manager( - params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, - params.ptr_K, params.shape_K, params.stride_K, - params.ptr_V, params.headdim_v, params.stride_V, - params.page_size_divmod, - params.page_size_divmod /*blockN_per_page_size_divmod, not used since we don't use TMA*/, - bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k, - 0 /*bidb_kv_idx, not used since we don't use TMA for Sm8x*/ - ); - - auto load_K = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) { - static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; - if constexpr (!PagedKV) { - // Do we need bound check to make sure the row doesn't go above kBlockN - static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; - Tensor tKsK_cur = tKsK(_, _, _, smem_pipe_write); - // Instead of passing in tKVcKV, we pass in t0KVcKV and subtract the offset from the limit - // (seqlen_k - n_block * kBlockN). This is because the entries of t0KVcKV are known at compile time. - int const seqlenk_row_limit = -int(get<0>(tKVcKV(_0{}, _0{}, _0{}))) + (EvenN - ? seqlen_info.seqlen_k - n_block * kBlockN - : (!Seqlenk_mask ? kBlockN : std::min(seqlen_info.seqlen_k - n_block * kBlockN, kBlockN))); - // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - flash::copy( - gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK_cur, t0KVcKV, tKVpKV, seqlenk_row_limit); - } else { - paged_kv_manager.template load_page_table(n_block); - paged_kv_manager.template load_K(n_block, sK(_, _, smem_pipe_write)); - } - }; - - auto load_V = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) { - static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; - if constexpr (!PagedKV) { - // Do we need bound check to make sure the row doesn't go above kBlockN - static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; - Tensor tVsV_cur = tVsV(_, _, _, smem_pipe_write); - // We don't call flash::copy since it doesn't support bound checking - // to not overshot kBlockN when writing to smem. - Tensor tVgV_cur = tVgV(_, _, _, n_block); - int const seqlenk_row_limit = seqlen_info.seqlen_k - n_block * kBlockN - get<0>(tKVcKV(_0{}, _0{}, _0{})); - #pragma unroll - for (int m = 0; m < size<1>(tVsV); ++m) { - // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to be checked - if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tKVcKV(_0{}, m, _0{})) < kBlockN) { - bool const predicate_n = !Seqlenk_mask || get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit; - #pragma unroll - for (int k = 0; k < size<2>(tVsV); ++k) { - cute::copy(gmem_tiled_copy_QKV.with(tKVpKV(k) && predicate_n), tVgV_cur(_, m, k), tVsV_cur(_, m, k)); - } - } - } - } else { - paged_kv_manager.template load_V(n_block, sV(_, _, smem_pipe_write)); - } - }; - - auto preprocess_Q = [&] { - if constexpr (!AppendKV) { - flash::cp_async_wait(); - } else { - if (get<1>(params.shape_rotary) > 0) { // Apply rotary to Q - using Rotary_t = Rotary; - Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, - params.ptr_rotary_sin, params.stride_rotary_sin, - params.is_rotary_interleaved, thread_idx, seqlen_q, - seqlen_info.seqlen_rotary); - int const qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor; - if (params.is_rotary_interleaved) { - auto [tRrCos, tRrSin] = cute::conditional_return( - rotary.template load_cos_sin(m_block), - rotary.template load_cos_sin_packgqa(m_block, params.qhead_per_khead_divmod) - ); - flash::cp_async_wait(); - __syncthreads(); - rotary.apply_Q_interleaved(sQ, tRrCos, tRrSin, m_block, qhead_per_khead); - } else { - auto [tRrCosCont, tRrSinCont] = cute::conditional_return( - rotary.template load_cos_sin(m_block), - rotary.template load_cos_sin_packgqa(m_block, params.qhead_per_khead_divmod) - ); - flash::cp_async_wait(); - __syncthreads(); - rotary.apply_Q_contiguous(sQ, tRrCosCont, tRrSinCont, m_block, qhead_per_khead); - } - } else { - flash::cp_async_wait(); - } - } - - if constexpr (Q_in_regs) { - __syncthreads(); - Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); - Tensor tSsQ_copy_view = smem_thr_copy_Q.partition_S(sQ); - cute::copy(smem_tiled_copy_Q, tSsQ_copy_view, tSrQ_copy_view); - } - }; - - // If Share_QV_Smem, we load Q, then load 1 stage of K, then (optionally) rotate Q and - // read from smem_q to registers, then load V. - // If !Share_QV, Smem, we load Q, load all stages of K & V, then (optionally) rotate Q. - - if constexpr (Share_QV_Smem) { - load_K(n_block, 0, cute::true_type{} /*Seqlenk_mask*/); - cute::cp_async_fence(); - preprocess_Q(); - __syncthreads(); // Make sure all threads have read smem_q before loading V - } - - // For persistent, make sure all threads have finished reading smem_o - if constexpr (!Share_QV_Smem) { __syncthreads(); } - // Note, using the for_each() function here to ensure `stage` is of type Int. - for_each(make_int_sequence{}, [&] (auto stage) { - static constexpr bool Is_first_stage = CUTE_STATIC_V(stage) == 0; - static constexpr bool Is_last_stage = CUTE_STATIC_V(stage) == kStages - 1; - if constexpr (!Share_QV_Smem || !Is_first_stage) { - if (Is_first_stage || n_block - stage >= n_block_min) { - load_K(n_block - stage, stage, cute::bool_constant{} /*Seqlenk_mask*/); - } - // We want the fence outside the if statement to have a fixed number of cp.async commits. - // so that we can wait with the correct number of outstanding commits. - cute::cp_async_fence(); - } - if constexpr (!Is_last_stage) { - if (Is_first_stage || n_block - stage >= n_block_min) { - load_V(n_block - stage, stage, cute::bool_constant{} /*Seqlenk_mask*/); - } - cute::cp_async_fence(); - } - }); - - if constexpr (!Share_QV_Smem) { preprocess_Q(); } - - flash::Mask mask( - thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, - params.qhead_per_khead_divmod - ); - - float softcap_val = params.softcap_val; - if constexpr (Has_softcap && Is_FP8) { - float const q_descale = params.ptr_q_descale == nullptr ? 1.0f : params.ptr_q_descale[bidb * get<0>(params.stride_q_descale) + bidh_kv * get<1>(params.stride_q_descale)]; - float const k_descale = params.ptr_k_descale == nullptr ? 1.0f : params.ptr_k_descale[bidb * get<0>(params.stride_k_descale) + bidh_kv * get<1>(params.stride_k_descale)]; - softcap_val *= q_descale * k_descale; - } - // Softcapping needs to happen before masking since if we apply after masking, softcapping can turn - // -inf to e.g. -50.0, which can affect the attention softmax. - auto scoremod_premask_fn = [&](auto& tSrS) { - if constexpr (Has_softcap) { flash::apply_softcap(tSrS, softcap_val); } - }; - - int smem_pipe_read = 0, smem_pipe_write = kStages - 1; - - auto load_K_next = [&] { - if (n_block - kStages >= n_block_min) { - load_K(n_block - kStages, kStages > 1 ? smem_pipe_write : 0, cute::false_type{} /*Seqlenk_mask*/); - } - cute::cp_async_fence(); - }; - - auto sync = [&] { - flash::cp_async_wait(); - __syncthreads(); - }; - - clear(tOrO); - - auto fwd_step = [&](int const n_block, auto mask_fn, auto is_first_iter_type, auto check_inf_type) { - static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value; - static constexpr bool Check_inf = decltype(check_inf_type)::value; - Tensor tSrS = partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{})); - clear(tSrS); - sync(); - auto load_V_next = [&] { - if (n_block - kStages + 1 >= n_block_min) { - load_V(n_block - kStages + 1, kStages > 1 ? smem_pipe_write : 0, cute::bool_constant{} /*Seqlenk_mask*/); - } - cute::cp_async_fence(); - }; - Tensor tSrQ_cur = cute::conditional_return(tSrQ, thr_mma.partition_fragment_A(sQ)); - Tensor tSrK = thr_mma.partition_fragment_B(sK(_, _, _0{})); - flash::gemm_sm80( - tSrS, tSrQ_cur, tSrK, tSsQ, tSsK(_, _, _, kStages > 1 ? smem_pipe_read : 0), - tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K, load_V_next - ); - smem_pipe_write = smem_pipe_write < kStages - 1 ? smem_pipe_write + 1 : 0; - scoremod_premask_fn(tSrS); - // Faster to load_K before gemm if we only have 1 stage - if constexpr (kStages == 1) { sync(); load_K_next(); } - mask_fn(tSrS, n_block); - Tensor scores_scale = softmax.template max_get_scale(tSrS); - softmax.template online_softmax(tSrS); - if constexpr (Is_FP8) { flash::permute_Cregs_fp8(tSrS); } - Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); - Tensor tOrP = make_tensor_like(tOrP_acc); - convert_type_out(tOrP_acc, tOrP); - if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); } - if constexpr (kStages > 1) { sync(); } - Tensor tOrV = thr_mma.partition_fragment_B(sVt(_, _, _0{})); - flash::gemm_rs_sm80(tOrO, tOrP, tOrV, tOsVt(_, _, _, kStages > 1 ? smem_pipe_read : 0), tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); - if constexpr (kStages > 1) { load_K_next(); } - smem_pipe_read = smem_pipe_read < kStages - 1 ? smem_pipe_read + 1 : 0; - }; - - auto first_iter_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; - fwd_step(n_block, first_iter_mask_fn, cute::true_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/); - --n_block; - if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking - auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; - int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM); - int const n_block_min_causal_local_mask = - std::max(n_block_min, (m_idx_min + seqlen_k - seqlen_q + params.window_size_right) / kBlockN); - #pragma unroll 1 - for (; n_block >= n_block_min_causal_local_mask; --n_block) { - fwd_step(n_block, mask_fn, cute::false_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/); - } - } - int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1; - int const n_block_min_before_local_mask = !Is_local - ? n_block_min - : std::max(n_block_min, - cute::ceil_div(m_idx_max + seqlen_k - seqlen_q - params.window_size_left, kBlockN)); - auto no_mask_fn = [](auto& tSrS, int n_block) { }; - #pragma unroll 1 - for (; n_block >= n_block_min_before_local_mask; --n_block) { - fwd_step(n_block, no_mask_fn, cute::false_type{} /*is_first_iter*/, cute::false_type{} /*check_inf*/); - } - // Separate masking iterations on the left for local attention - if constexpr (Is_local) { - auto local_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; - #pragma unroll 1 - for (; n_block >= n_block_min; --n_block) { - fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant{} /*check_inf*/); - } - } - float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; - Tensor scores_scale = softmax.finalize(v_descale); - softmax.rescale_o(tOrO, scores_scale); - if constexpr (Is_FP8) { flash::permute_output_fp8(tOrO); } - return true; - } - - template - CUTLASS_DEVICE bool - store_kv_new(Params const& params, - int const thread_idx, - SharedStorage &shared_storage, - SeqlenInfo_t const& seqlen_info, - cute::tuple block_coord - ) { - auto [m_block, bidh, bidb, split_idx] = block_coord; - auto n_block_new_min_max = BlockMN_t::get_n_block_k_new_min_max( - seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); - int const n_block_new_min = get<0>(n_block_new_min_max); - int const n_block_new_max = get<1>(n_block_new_min_max); - if (n_block_new_max <= n_block_new_min) { return false; } - - Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); - Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); - - int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; - int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb]; - - bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new; - Tensor mKnew = make_tensor(make_gmem_ptr(params.ptr_K_new), params.shape_K_new, params.stride_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); - Tensor mVnew = make_tensor(make_gmem_ptr(params.ptr_V_new), params.shape_K_new, params.stride_V_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); - - bool const is_varlen_k = Varlen && params.cu_seqlens_k; - Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); - Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); - - Tensor gKnew = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gVnew = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mVnew), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - int const offset_k = seqlen_info.offset_k + seqlen_info.seqlen_k_og; - Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - static constexpr int kHeadDim = get<2>(TileShape_MNK{}); - int const seqlen_k_new = seqlen_info.seqlen_k_new; - using Rotary_t = Rotary; - Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, - params.ptr_rotary_sin, params.stride_rotary_sin, - params.is_rotary_interleaved, thread_idx, seqlen_k_new, - seqlen_info.seqlen_rotary); - - using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; - PagedKVManager_t paged_kv_manager( - params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, - params.ptr_K, params.shape_K, params.stride_K, - params.ptr_V, params.headdim_v, params.stride_V, - params.page_size_divmod, - params.page_size_divmod /*blockN_per_page_size_divmod, not used since we don't use TMA*/, - bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k, - // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position - 0 /*bidb_kv_idx, not used since we don't use TMA for Sm8x*/ - ); - - static_assert(std::is_same_v); - static_assert(!PagedKV || std::is_same_v); - GmemTiledCopyQKV gmem_tiled_copy_kv_g2s; - auto gmem_thr_copy_kv_g2s = gmem_tiled_copy_kv_g2s.get_thread_slice(thread_idx); - auto gmem_thr0_copy_kv_g2s = gmem_tiled_copy_kv_g2s.get_thread_slice(_0{}); // Only for index calculation - GmemTiledCopyAppendKV gmem_tiled_copy_kv_s2g; - auto gmem_thr_copy_kv_s2g = gmem_tiled_copy_kv_s2g.get_thread_slice(thread_idx); - auto gmem_thr0_copy_kv_s2g = gmem_tiled_copy_kv_s2g.get_thread_slice(_0{}); // Only for index calculation - Tensor tKgKnew = gmem_thr_copy_kv_g2s.partition_S(gKnew); - Tensor tKsKg2s = gmem_thr_copy_kv_g2s.partition_S(sK); - Tensor tKsKs2g = gmem_thr_copy_kv_s2g.partition_S(sK); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tKgK = gmem_thr_copy_kv_s2g.partition_D(gK); - Tensor tVgVnew = gmem_thr_copy_kv_g2s.partition_S(gVnew); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tVsVg2s = gmem_thr_copy_kv_g2s.partition_S(sV); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tVsVs2g = gmem_thr_copy_kv_s2g.partition_S(sV); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tVgV = gmem_thr_copy_kv_s2g.partition_D(gV); - Tensor cK = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k) - Tensor tKcKg2s = gmem_thr_copy_kv_g2s.partition_D(cK); - Tensor t0KcKg2s = gmem_thr0_copy_kv_g2s.partition_D(cK); - Tensor tKpKg2s = make_tensor(make_shape(size<2>(tKsKg2s))); - Tensor tKcKs2g = gmem_thr_copy_kv_s2g.partition_D(cK); - Tensor t0KcKs2g = gmem_thr0_copy_kv_s2g.partition_D(cK); - Tensor tKpKs2g = make_tensor(make_shape(size<2>(tKsKs2g))); - #pragma unroll - for (int k = 0; k < size(tKpKg2s); ++k) { tKpKg2s(k) = get<1>(tKcKg2s(_0{}, _0{}, k)) < get<1>(params.shape_K); } - #pragma unroll - for (int k = 0; k < size(tKpKs2g); ++k) { tKpKs2g(k) = get<1>(tKcKs2g(_0{}, _0{}, k)) < get<1>(params.shape_K); } - - auto load_K_new = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) { - static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; - static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; - Tensor tKsK_cur = tKsKg2s(_, _, _, smem_pipe_write); - int const seqlenk_row_limit = -int(get<0>(tKcKg2s(_0{}, _0{}, _0{}))) + (EvenN - ? seqlen_k_new - n_block * kBlockN - : (!Seqlenk_mask ? kBlockN : std::min(seqlen_k_new - n_block * kBlockN, kBlockN))); - // We don't need to clear the sK smem tiles since we won't write them out - flash::copy( - gmem_tiled_copy_kv_g2s, tKgKnew(_, _, _, n_block), tKsK_cur, t0KcKg2s, tKpKg2s, seqlenk_row_limit); - }; - - auto load_V_new = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) { - static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; - static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; - Tensor tVsV_cur = tVsVg2s(_, _, _, smem_pipe_write); - int const seqlenk_row_limit = -int(get<0>(tKcKg2s(_0{}, _0{}, _0{}))) + (EvenN - ? seqlen_k_new - n_block * kBlockN - : (!Seqlenk_mask ? kBlockN : std::min(seqlen_k_new - n_block * kBlockN, kBlockN))); - // We don't need to clear the sV smem tiles since we won't write them out - flash::copy( - gmem_tiled_copy_kv_g2s, tVgVnew(_, _, _, n_block), tVsV_cur, t0KcKg2s, tKpKg2s, seqlenk_row_limit); - }; - - auto store_K = [&] (int const n_block, int const smem_pipe_read) { - int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN); - if (get<1>(params.shape_rotary) <= 0) { - Tensor tKsK_cur = tKsKs2g(_, _, _, smem_pipe_read); - if constexpr (!PagedKV) { - Tensor tKgK_cur = tKgK(_, _, _, n_block); - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_kv_s2g, tKsK_cur, tKgK_cur, tKcKs2g, tKpKs2g, std::min(seqlen_k_new - n_block * kBlockN, kBlockN) - ); - } else { - paged_kv_manager.store_K(n_block, tKsK_cur); - } - } else { - Tensor gK_cur = gK(_, _, n_block); - auto tPrKPtr = cute::conditional_return(paged_kv_manager.compute_K_ptr(), nullptr); - if (params.is_rotary_interleaved) { - auto [tRrCos, tRrSin] = rotary.template load_cos_sin(n_block); - rotary.template apply_K_interleaved(sK(_, _, smem_pipe_read), gK_cur, tKpKs2g, tRrCos, tRrSin, tPrKPtr, n_block); - } else { - auto [tRrCosCont, tRrSinCont] = rotary.template load_cos_sin(n_block); - rotary.template apply_K_contiguous(sK(_, _, smem_pipe_read), gK_cur, tKpKs2g, tRrCosCont, tRrSinCont, tPrKPtr, n_block, get<1>(params.shape_K)); - } - } - }; - - auto store_V = [&] (int const n_block, int const smem_pipe_read) { - int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN); - Tensor tVsV_cur = tVsVs2g(_, _, _, smem_pipe_read); - if constexpr (!PagedKV) { - Tensor tVgV_cur = tVgV(_, _, _, n_block); - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_kv_s2g, tVsV_cur, tVgV_cur, tKcKs2g, tKpKs2g, n_limit); - } else { - paged_kv_manager.store_V(n_block, tVsV_cur); - } - }; - - int n_block = n_block_new_max - 1; - // Note, using the for_each() function here to ensure `stage` is of type Int. - for_each(make_int_sequence{}, [&] (auto stage) { - static constexpr bool Is_first_stage = CUTE_STATIC_V(stage) == 0; - static constexpr bool Is_last_stage = CUTE_STATIC_V(stage) == kStages - 1; - if (Is_first_stage || n_block - stage >= n_block_new_min) { - load_K_new(n_block - stage, stage, cute::bool_constant{} /*Seqlenk_mask*/); - } - cute::cp_async_fence(); - // If persistent, need to sync to make sure all threads have finished with smem_o before writing to smem_v - if constexpr (Is_first_stage) { __syncthreads(); } - if constexpr (!Is_last_stage) { - if (Is_first_stage || n_block - stage >= n_block_new_min) { - load_V_new(n_block - stage, stage, cute::bool_constant{} /*Seqlenk_mask*/); - } - cute::cp_async_fence(); - } - }); - - int smem_pipe_read = 0, smem_pipe_write = kStages - 1; - #pragma unroll 1 - for (; n_block >= n_block_new_min; --n_block) { - if constexpr (PagedKV) { paged_kv_manager.template load_page_table(n_block); } - flash::cp_async_wait(); - __syncthreads(); - store_K(n_block, kStages > 1 ? smem_pipe_read : 0); - if (n_block - kStages + 1 >= n_block_new_min) { - load_V_new(n_block - kStages + 1, kStages > 1 ? smem_pipe_write : 0, cute::bool_constant{} /*Seqlenk_mask*/); - } - cute::cp_async_fence(); - smem_pipe_write = smem_pipe_write < kStages - 1 ? smem_pipe_write + 1 : 0; - flash::cp_async_wait(); - __syncthreads(); - store_V(n_block, kStages > 1 ? smem_pipe_read : 0); - smem_pipe_read = smem_pipe_read < kStages - 1 ? smem_pipe_read + 1 : 0; - if (n_block - kStages >= n_block_new_min) { - load_K_new(n_block - kStages, kStages > 1 ? smem_pipe_write : 0, cute::false_type{} /*Seqlenk_mask*/); - } - cute::cp_async_fence(); - } - - return true; - - } - -}; - -} // namespace flash diff --git a/flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp b/flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp deleted file mode 100644 index 0bdd4191538641a5c373e7f4b8dc9f52f95cc37a..0000000000000000000000000000000000000000 --- a/flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ /dev/null @@ -1,1759 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include -#include -#include -#include -#include "cutlass/pipeline/pipeline.hpp" - -#include "cute/tensor.hpp" - -#include "cutlass/gemm/collective/builders/sm90_common.inl" - -#include "named_barrier.hpp" -#include "seqlen.h" -#include "block.h" -#include "mask.h" -#include "pack_gqa.h" -#include "paged_kv.h" -#include "rotary.h" -#include "utils.h" -#include "sm90_pipeline_no_cluster.hpp" - -namespace flash { - -using namespace cute; - -template -struct CollectiveMainloopFwdSm90 { - - static constexpr int kStages = Stages; - using ClusterShape = ClusterShape_; - using TileShape_MNK = TileShape_MNK_; - using TileShape_MNK_PV = Shape(TileShape_MNK{})), Int, decltype(get<1>(TileShape_MNK{}))>; - using TileShape_MNK_QV = Shape(TileShape_MNK{})), decltype(get<1>(TileShape_MNK{})), Int>; - using Element = Element_; - using ElementAccum = ElementAccum_; - using ElementSAux = ElementSAux_; - using ArchTag = ArchTag_; - static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v;; - static constexpr bool Is_causal = Is_causal_; - static constexpr bool Is_local = Is_local_; - static constexpr bool Has_softcap = Has_softcap_; - static constexpr bool Varlen = Varlen_; - static constexpr bool PagedKVNonTMA = PagedKVNonTMA_; - static constexpr bool AppendKV = AppendKV_; - static constexpr bool HasQv = HasQv_; - static constexpr bool PackGQA = PackGQA_; - static constexpr bool Split = Split_; - static constexpr bool V_colmajor = V_colmajor_; - static constexpr bool Transpose_V = Is_FP8 && !V_colmajor; - static constexpr bool Use_TMA_Q = !PackGQA; - static constexpr bool Use_TMA_KV = !PagedKVNonTMA; - static_assert(Use_TMA_KV || CUTE_STATIC_V(size(ClusterShape{})) == 1, "If not using TMA for KV, ClusterShape must be 1"); - static_assert(Use_TMA_KV || !V_colmajor, "If not using TMA for KV, V_colmajor is not supported"); - static constexpr bool SameHeadDim = get<2>(TileShape_MNK{}) == kHeadDimV; - static constexpr bool LargeHeadDimV = kHeadDimV > 256; - - static_assert(ArchTag::kMinComputeCapability >= 90); - - static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 && !V_colmajor ? GMMA::Major::MN : GMMA::Major::K; - static constexpr cute::GMMA::Major TmaMajorV = !V_colmajor ? GMMA::Major::MN : GMMA::Major::K; - - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - static constexpr int kHeadDim = get<2>(TileShape_MNK{}); - - using SeqlenInfo_t = flash::SeqlenInfoQKNewK; - using BlockMN_t = flash::BlockMN; - - static_assert(!LargeHeadDimV || kHeadDimV % 256 == 0); - static_assert(!LargeHeadDimV || kBlockM <= 64, "kBlockM must be 64 or less for large Headdim_V"); - static_assert(!LargeHeadDimV || !MmaPV_is_RS, "MmaPV must be SS for large Headdim_V"); - - // Register bandwidth is actually a bottleneck so we don't want Q to be in registers. - // Leaving this option here for reference. - static constexpr bool MmaQK_is_RS = false; - // We can have MmaPV with P in smem in rmem to reduce register pressure at the cost of more smem. - static_assert(!(!MmaPV_is_RS && Is_FP8), "MmaPV must be RS if FP8"); - static_assert(!(!MmaPV_is_RS && Transpose_V), "MmaPV must be RS if Transpose_V"); - - // Slightly faster in this case to have WG1 use RS instead of SS to avoid waiting for the P smem write - static constexpr bool MmaPV_use_RS_WG1 = !MmaPV_is_RS && kHeadDim == 64 && kHeadDimV == 512; - - using AtomLayoutQK = Layout, _1, _1>>; - using TiledMmaQK = decltype(cute::make_tiled_mma( - std::conditional_t< - !MmaQK_is_RS, - decltype(cute::GMMA::ss_op_selector()), - decltype(cute::GMMA::rs_op_selector()) - >{}, - AtomLayoutQK{})); - using AtomLayoutPV = std::conditional_t< - !LargeHeadDimV, - AtomLayoutQK, - Layout, _1>> - >; - using TiledMmaPV = decltype(cute::make_tiled_mma( - std::conditional_t< - !MmaPV_is_RS, - decltype(cute::GMMA::ss_op_selector()), - decltype(cute::GMMA::rs_op_selector()) - >{}, - AtomLayoutPV{})); - using TiledMmaQV = decltype(cute::make_tiled_mma( - cute::GMMA::ss_op_selector(), - AtomLayoutQK{})); - // For hdim64,512, WG1 can use RS but WG2 must use SS - using TiledMmaPV_RS = decltype(cute::make_tiled_mma( - cute::GMMA::rs_op_selector(), - AtomLayoutPV{})); - - static constexpr int NumMmaThreadsQK = size(TiledMmaQK{}); - static constexpr int NumMmaThreads = size(TiledMmaPV{}); - static constexpr int NumProducerThreads = !Transpose_V && Use_TMA_KV && Use_TMA_Q ? cutlass::NumThreadsPerWarp : cutlass::NumThreadsPerWarpGroup; - static_assert(NumMmaThreadsQK % cutlass::NumThreadsPerWarpGroup == 0); - static_assert(NumMmaThreads % cutlass::NumThreadsPerWarpGroup == 0); - static constexpr int NumMmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup; - static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); - - using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); - - using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutK = decltype(tile_to_shape( - SmemLayoutAtomK{}, - make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); - - using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::ss_smem_selector, decltype(cute::get<2>(TileShape_MNK_PV{}))>()); - using SmemLayoutVt = decltype(tile_to_shape( - SmemLayoutAtomVt{}, - make_shape(Int{}, shape<2>(TileShape_MNK_PV{}), Int{}), - std::conditional_t, cute::Step<_2, _1, _3>>{})); - - using SmemLayoutAtomVtMma = decltype(cutlass::gemm::collective::detail::ss_smem_selector, decltype(cute::get<2>(TileShape_MNK_PV{}))>()); - using SmemLayoutVtMma = decltype(tile_to_shape( - SmemLayoutAtomVtMma{}, - make_shape(Int{}, shape<2>(TileShape_MNK_PV{}), Int{}), - std::conditional_t, cute::Step<_2, _1, _3>>{})); - - using SmemLayoutAtomQv = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_QV{})), decltype(cute::get<2>(TileShape_MNK_QV{}))>()); - using SmemLayoutQv = decltype(tile_to_shape(SmemLayoutAtomQv{}, select<0, 2>(TileShape_MNK_QV{}))); - using SmemLayoutAtomVMmaQV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_QV{})), decltype(cute::get<2>(TileShape_MNK_QV{}))>()); - using SmemLayoutVMmaQV = decltype(tile_to_shape( - SmemLayoutAtomVMmaQV{}, - make_shape(shape<1>(TileShape_MNK_QV{}), Int{}, Int{}))); - static_assert(CUTE_STATIC_V(size(SmemLayoutVMmaQV{})) == size(SmemLayoutVtMma{})); - - // Only used if we're using cp.async to load V - using SmemLayoutAtomVCpAsync = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), Int>()); - using SmemLayoutVCpAsync = decltype(tile_to_shape( - SmemLayoutAtomVCpAsync{}, - make_shape(shape<1>(TileShape_MNK{}), Int{}, Int{}))); - - using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); - using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{}))); - - // Only for LargeHeadDimV where WG0 sends WG1 the scales - using SmemLayoutScale = cute::Layout, Int>>; - - using SmemCopyAtomP = Copy_Atom; - - // Hardcoded to be at most 64 query heads - using SmemLayoutSAux = Layout>; - - // Use LDSM.T and STSM to transpose V in the case of FP8 and V being row-major. - // For FP16/BF16 we don't do any transposing. - static_assert(!Transpose_V || (kHeadDimV % 32 == 0 && kBlockN % 32 == 0)); - static constexpr bool kHeadDimV_multiple_64 = kHeadDimV % 64 == 0; - // Either kHeadDimV is a multiple of 64 (in which case we use a block size of 64 x 32 for the transpose), - // or we need kBlockN to be a multiple of 64 (in which case we use a block size of 32 x 64 for the transpose). - static_assert(!Transpose_V || (kHeadDimV_multiple_64 || kBlockN % 64 == 0)); - using LDSM_thread_shape = std::conditional_t, Shape<_16, _4, _1, _2>>; - using LDSM_thread_stride = std::conditional_t, Stride<_4, _1, _0, _64>>; - using LDSM_value_shape = Shape<_2, _2, _1, _4>; - using LDSM_value_stride = Stride<_1, _2, _16, _4>; - using LDSM_divide_shape = std::conditional_t, Shape<_32, _8>>; - using S2RTiledCopyVt = decltype(make_tiled_copy( - Copy_Atom{}, Layout{}, - Layout{})); - - using STSM_thread_shape = std::conditional_t, Shape<_8, _4, _2, _2>>; - using STSM_thread_stride = std::conditional_t, Stride<_4, _1, _32, _64>>; - using STSM_value_shape = Shape<_1, _4, _2, _2>; - using STSM_value_stride = Stride<_0, _1, _4, _8>; - using STSM_divide_shape = Shape<_8, _16>; - // These will not permute the columns of V (the kHeadDimV dimension) but incur bank conflicts - // so a little slower (e.g. 1150 TFLOPS for hdim 256 instead of 1200 TFLOPS). - // Instead we will permute the cols of V, and un-permute the cols of O in the epilogue. - // using STSM_value_shape = Shape<_2, _4, _1, _2>; - // using STSM_value_stride = Stride<_4, _1, _0, _8>; - // using STSM_divide_shape = Shape<_16, _16>; - using R2STiledCopyV = decltype(make_tiled_copy( - Copy_Atom{}, Layout{}, - Layout{})); - - using GmemTiledCopyQ = cute::SM90_TMA_LOAD; - using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{}))); - - // We use CpAsync for K and V if PagedKVNonTMA and AppendKV, since TMA doesn't work there - static constexpr int kHeadDimGCD = cute::gcd(kHeadDim, kHeadDimV); - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDimGCD % kGmemElemsPerLoad == 0, "Headdim and HeaddimV must be a multiple of kGmemElemsPerLoad"); - // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each - // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. - // We want each thread to have at least 2 loads in the K direction since in the case of non-interleaved - // rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), each thread will - // load twice from the same row. - static constexpr int kBytePerHalfRow = kHeadDimGCD / 2 * sizeof(Element); - static constexpr int kBlockKGmem = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element); - static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; - static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); - // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKVNonTMA where - // these threads share the same page table entry and share the work of computing pointers to paged K and paged V. - static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); - using GmemLayoutAtom = Layout, Int>, - Stride, _1>>; - // If AppendKV, we'll be loading Q for rotary, and we assume divisibility to avoid predication - static_assert(!AppendKV || kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumMmaThreads / kGmemThreadsPerRow"); - using GmemTiledCopyAppendKV = decltype( - make_tiled_copy(Copy_Atom, Element>{}, - GmemLayoutAtom{}, - Layout>>{})); // Val layout, 8 or 16 vals per store - - using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) - using StrideQK = cute::Stride; - using StrideV = std::conditional_t>; - // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits) - using ShapeQPacked = std::conditional_t, int32_t, int32_t, int32_t>>; - using StrideQPacked = std::conditional_t, _1, int64_t, int64_t>>; - using ShapePageTable = cute::Shape; // (batch, max_num_pages_per_seq) - using StridePageTable = cute::Stride; - using ShapeRotary = cute::Shape; // (seqlen_ro, rotary_dim // 2) - using StrideRotary = cute::Stride; - using StrideDescale = cute::Stride; - - using TMA_Q = decltype(make_tma_copy_A_sm90( - GmemTiledCopyQ{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), - SmemLayoutQ{}, - TileShape_MNK{}, - ClusterShape{})); - - using TMA_K = decltype(make_tma_copy_B_sm90( - GmemTiledCopyKV{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), - take<0, 2>(SmemLayoutK{}), - TileShape_MNK{}, - ClusterShape{})); // mcast along M mode for this N load, if any - - using TMA_V = decltype(make_tma_copy( - GmemTiledCopyKV{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, select<1, 0, 2, 3>(StrideV{})), - take<0, 2>(SmemLayoutVt{}), - select<1, 2>(TileShape_MNK_PV{}), - size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any - - using TMA_Qv_ = decltype(make_tma_copy_A_sm90( - GmemTiledCopyQ{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), - SmemLayoutQv{}, - TileShape_MNK_QV{}, - ClusterShape{})); - using TMA_Qv = std::conditional_t; - - // Set the bytes transferred in this TMA transaction (may involve multiple issues) - static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); - static constexpr uint32_t TmaTransactionBytesK = static_cast(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); - static constexpr uint32_t TmaTransactionBytesV = static_cast(size(take<0, 2>(SmemLayoutVt{})) * cutlass::sizeof_bits_v / 8); - static constexpr uint32_t TmaTransactionBytesQv = static_cast(size(SmemLayoutQv{}) * cutlass::sizeof_bits_v / 8); - - using PipelineTmaAsync = std::conditional_t, typename cutlass::PipelineTmaAsync>; - using MainloopPipelineK = std::conditional_t>; - using MainloopPipelineV = std::conditional_t>; - using MainloopPipelineVt = std::conditional_t>; - // We always use TMA for K_new and V_new - using MainloopPipelineKVNew = PipelineTmaAsync; - using PipelineState = cutlass::PipelineState; - - // If PackGQA, we use cp.async (instead of TMA) to load Q, so we want smem_q to be aligned - // and have sQ being position_independent_swizzle_tensor. - // If !Use_TMA_KV, we use cp.async (instead of TMA) to load K & V, so we want smem_k and smem_v to be aligned. - static constexpr size_t SmemAlignmentQ = Use_TMA_Q && !MmaQK_is_RS ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQ{}); - static constexpr size_t SmemAlignmentK = Use_TMA_KV && !AppendKV ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutK{}); - static constexpr size_t SmemAlignmentVtNoTranspose = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); - static constexpr size_t SmemAlignmentQv = Use_TMA_Q ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQv{}); - static_assert(SmemAlignmentQ >= 128 and SmemAlignmentK >= 128 && SmemAlignmentVtNoTranspose >= 128, "Require at least 128B alignment"); - static constexpr size_t SmemAlignmentP = cutlass::detail::alignment_for_swizzle(SmemLayoutP{}); - static_assert(SmemAlignmentP >= 128, "Require at least 128B alignment"); - - using SmemP_t = std::conditional_t, cute::array_aligned, SmemAlignmentP>>; - using SmemScale_t = std::conditional_t, cute::array_aligned, 128>>; - using SmemQv_t = std::conditional_t, cute::array_aligned, SmemAlignmentQv>>; - // Sometimes even with SmemP_t = cute::array, putting it in the TensorStorage struct causes - // smem size to go from 227KB to 228KB and we get "invalid argument". - - struct TensorStorageWithoutPNoTranspose : cute::aligned_struct { - cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; - cute::array_aligned, SmemAlignmentQ> smem_q; - cute::array_aligned, SmemAlignmentK> smem_k; - SmemQv_t smem_qv; - cute::array_aligned, 128> smem_s_aux; - }; - - struct TensorStorageWithPNoTranspose : cute::aligned_struct { - cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; - cute::array_aligned, SmemAlignmentQ> smem_q; - cute::array_aligned, SmemAlignmentK> smem_k; - SmemQv_t smem_qv; - SmemP_t smem_p; - cute::array_aligned, 128> smem_s_aux; - }; - struct TensorStorageWithPScaleNoTranspose : cute::aligned_struct { - cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; - cute::array_aligned, SmemAlignmentQ> smem_q; - cute::array_aligned, SmemAlignmentK> smem_k; - SmemQv_t smem_qv; - SmemP_t smem_p; - SmemScale_t smem_scale; - cute::array_aligned, 128> smem_s_aux; - }; - - using TensorStorageNoTranspose = std::conditional_t< - MmaPV_is_RS, - TensorStorageWithoutPNoTranspose, - std::conditional_t - >; - - static constexpr size_t SmemAlignmentVt = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); - static constexpr size_t SmemAlignmentV = cutlass::detail::alignment_for_swizzle(SmemLayoutVtMma{}); - static_assert(SmemAlignmentVt >= 128 and SmemAlignmentV >= 128, "Require at least 128B alignment"); - struct TensorStorageTransposeV : cute::aligned_struct { - cute::array_aligned, SmemAlignmentV> smem_v; - cute::array_aligned, SmemAlignmentVt> smem_vt; - cute::array_aligned, SmemAlignmentQ> smem_q; - cute::array_aligned, SmemAlignmentK> smem_k; - SmemQv_t smem_qv; - SmemScale_t smem_scale; - cute::array_aligned, 128> smem_s_aux; - }; - - using TensorStorage = std::conditional_t; - - // These are tuned for speed. They don't affect correctness. - static constexpr bool UseSchedulerBarrier = (IntraWGOverlap - ? (NumMmaWarpGroups >= 2) && (!Is_FP8 ? kHeadDim <= 128 : kHeadDim >= 128) - : NumMmaWarpGroups == 2) - && !LargeHeadDimV; - static constexpr bool RescaleOBeforeGemm = kHeadDim > 128 && (!Is_FP8 || V_colmajor) && IntraWGOverlap; - - // Host side kernel arguments - struct Arguments { - Element const* const ptr_Q; - ShapeQKV const shape_Q; - StrideQK const stride_Q; - Element* const ptr_K; // not Element const* since we might append to KV cache in-place - ShapeQKV const shape_K; - StrideQK const stride_K; - Element* const ptr_V; - int32_t const headdim_v; - StrideV const stride_V; - Element const* const ptr_K_new; - ShapeQKV const shape_K_new; - StrideQK const stride_K_new; - Element const* const ptr_V_new; - StrideV const stride_V_new; - Element const* const ptr_Qv; - StrideQK const stride_Qv; - Element const* const ptr_rotary_cos; - ShapeRotary const shape_rotary; - StrideRotary const stride_rotary_cos; - Element const* const ptr_rotary_sin; - StrideRotary const stride_rotary_sin; - bool const is_rotary_interleaved; - int const* const ptr_pagetable; - ShapePageTable const shape_pagetable; - StridePageTable const stride_pagetable; - float const softmax_scale; - float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; - StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; - int const window_size_left = -1, window_size_right = -1; - float const softcap_val; - int const num_splits; - int const* const kv_batch_idx = nullptr; - int const* const cu_seqlens_q = nullptr; - int const* const cu_seqlens_k = nullptr; - int const* const cu_seqlens_k_new = nullptr; - int const* const seqused_q = nullptr; - int const* const seqused_k = nullptr; - int const* const leftpad_k = nullptr; - int const* const seqlens_rotary = nullptr; - ElementSAux const* const ptr_S_aux = nullptr; - }; - - // Device side kernel params - struct Params { - Element const* const ptr_Q; - ShapeQKV const shape_Q; - StrideQK const stride_Q; - ShapeQPacked const shape_Q_packed; - StrideQPacked const stride_Q_packed; - Element* const ptr_K; - ShapeQKV const shape_K; - StrideQK const stride_K; - Element* const ptr_V; - int32_t const headdim_v; - StrideV const stride_V; - Element const* const ptr_K_new; - ShapeQKV const shape_K_new; - StrideQK const stride_K_new; - Element const* const ptr_V_new; - StrideV const stride_V_new; - Element const* const ptr_Qv; - StrideV const stride_Qv; - ShapeQPacked const shape_Qv_packed; - StrideQPacked const stride_Qv_packed; - Element const* const ptr_rotary_cos; - ShapeRotary const shape_rotary; - StrideRotary const stride_rotary_cos; - Element const* const ptr_rotary_sin; - StrideRotary const stride_rotary_sin; - bool const is_rotary_interleaved; - int const* const ptr_pagetable; - ShapePageTable const shape_pagetable; - StridePageTable const stride_pagetable; - cutlass::FastDivmod page_size_divmod; - cutlass::FastDivmod blockN_per_page_size_divmod; - cutlass::FastDivmod qhead_per_khead_divmod; - TMA_Q tma_load_Q; - TMA_K tma_load_K; - TMA_V tma_load_V; - TMA_K tma_load_K_new; - TMA_V tma_load_V_new; - TMA_Qv tma_load_Qv; - float const softmax_scale_log2; - float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; - StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; - float const softcap_val; - int const window_size_left, window_size_right; - int const num_splits; - int const* const kv_batch_idx = nullptr; - int const* const cu_seqlens_q = nullptr; - int const* const cu_seqlens_k = nullptr; - int const* const cu_seqlens_k_new = nullptr; - int const* const seqused_q = nullptr; - int const* const seqused_k = nullptr; - int const* const leftpad_k = nullptr; - int const* const seqlens_rotary = nullptr; - ElementSAux const* const ptr_S_aux = nullptr; - }; - - static Params - to_underlying_arguments(Arguments const& args) { - Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q); - TMA_Q tma_load_Q = make_tma_copy_A_sm90( - GmemTiledCopyQ{}, - mQ, - SmemLayoutQ{}, - TileShape_MNK{}, - ClusterShape{}); // no mcast for Q - Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K); - TMA_K tma_load_K = make_tma_copy_B_sm90( - GmemTiledCopyKV{}, - mK, - take<0, 2>(SmemLayoutK{}), - TileShape_MNK{}, - ClusterShape{}); // mcast along M mode for this N load, if any - Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), - make_shape(args.headdim_v, get<0>(args.shape_K), get<2>(args.shape_K), get<3>(args.shape_K)), - select<1, 0, 2, 3>(args.stride_V)); - TMA_V tma_load_V = make_tma_copy( - GmemTiledCopyKV{}, - mV, - take<0, 2>(SmemLayoutVt{}), - select<1, 2>(TileShape_MNK_PV{}), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any - Tensor mKnew = make_tensor(make_gmem_ptr(args.ptr_K_new), args.shape_K_new, args.stride_K_new); - TMA_K tma_load_K_new = make_tma_copy_B_sm90( - GmemTiledCopyKV{}, - cute::conditional_return(mKnew, mK), - take<0, 2>(SmemLayoutK{}), - TileShape_MNK{}, - ClusterShape{}); // mcast along M mode for this N load, if any - Tensor mVnew = make_tensor(make_gmem_ptr(args.ptr_V_new), - make_shape(args.headdim_v, get<0>(args.shape_K_new), get<2>(args.shape_K_new), get<3>(args.shape_K_new)), - select<1, 0, 2, 3>(args.stride_V_new)); - TMA_V tma_load_V_new = make_tma_copy( - GmemTiledCopyKV{}, - cute::conditional_return(mVnew, mV), - take<0, 2>(SmemLayoutVt{}), - select<1, 2>(TileShape_MNK_PV{}), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any - auto shape_Qv = make_shape(get<0>(args.shape_Q), args.headdim_v, get<2>(args.shape_Q), get<3>(args.shape_Q)); - Tensor mQv = make_tensor(make_gmem_ptr(args.ptr_Qv), shape_Qv, args.stride_Qv); - TMA_Qv tma_load_Qv = [&] { - if constexpr (HasQv) { - return make_tma_copy_A_sm90( - GmemTiledCopyQ{}, - mQv, - SmemLayoutQv{}, - TileShape_MNK_QV{}, - ClusterShape{}); // no mcast for Qv - } else { - return nullptr; - } - }(); - // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size) - int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K)); - auto const shape_Q_packed = cute::conditional_return( - args.shape_Q, - make_shape(make_shape(qhead_per_khead, get<0>(args.shape_Q)), get<1>(args.shape_Q), get<2>(args.shape_K), get<3>(args.shape_Q)) - ); - auto const stride_Q_packed = cute::conditional_return( - args.stride_Q, - make_stride(make_stride(get<2>(args.stride_Q), get<0>(args.stride_Q)), get<1>(args.stride_Q), get<2>(args.stride_Q) * qhead_per_khead, get<3>(args.stride_Q)) - ); - auto const shape_Qv_packed = cute::conditional_return( - shape_Qv, - make_shape(make_shape(qhead_per_khead, get<0>(shape_Qv)), get<1>(shape_Qv), get<2>(args.shape_K), get<3>(shape_Qv)) - ); - auto const stride_Qv_packed = cute::conditional_return( - args.stride_Qv, - make_stride(make_stride(get<2>(args.stride_Qv), get<0>(args.stride_Qv)), get<1>(args.stride_Qv), get<2>(args.stride_Qv) * qhead_per_khead, get<3>(args.stride_Qv)) - ); - if (get<1>(args.shape_rotary) > 0) { - assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr); - } - assert(args.num_splits >= 1); - int page_size = !args.ptr_pagetable ? 1 : get<0>(args.shape_K); - if (!PagedKVNonTMA && args.ptr_pagetable != nullptr) { - assert(page_size % kBlockN == 0); - assert(!args.leftpad_k); - } - // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. - // Right after this, we multiply by log2(e) before applying exp2. - // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val - // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) - // (assigning it to params.softmax_scale_log2). - return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed, - args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V, - args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new, - args.ptr_Qv, args.stride_Qv, shape_Qv_packed, stride_Qv_packed, - args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos, - args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved, - args.ptr_pagetable, args.shape_pagetable, args.stride_pagetable, - cutlass::FastDivmod(page_size), // page_size_divmod - cutlass::FastDivmod(!args.ptr_pagetable ? 1 : cute::ceil_div(page_size, kBlockN)), // blockN_per_page_size_divmod - cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), - tma_load_Q, tma_load_K, tma_load_V, tma_load_K_new, tma_load_V_new, tma_load_Qv, - !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), - args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale, - args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, - !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, - args.window_size_left, args.window_size_right, - !Split ? 1 : args.num_splits, - args.kv_batch_idx, - args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, - args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary, - args.ptr_S_aux}; - } - - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& params) { - if constexpr (Use_TMA_Q) { - cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor()); - if constexpr (HasQv) { - cute::prefetch_tma_descriptor(params.tma_load_Qv.get_tma_descriptor()); - } - } - if constexpr (Use_TMA_KV) { - cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor()); - cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor()); - } - if constexpr (AppendKV) { - cute::prefetch_tma_descriptor(params.tma_load_K_new.get_tma_descriptor()); - cute::prefetch_tma_descriptor(params.tma_load_V_new.get_tma_descriptor()); - } - } - - template - CUTLASS_DEVICE void - load(Params const& params, - MainloopPipelineK pipeline_k, - MainloopPipelineV pipeline_v, - MainloopPipelineVt pipeline_vt, - PipelineState& smem_pipe_write, - SharedStorage &shared_storage, - SchedulerPrefetch const& scheduler_prefetch, - SeqlenInfo_t const& seqlen_info, - cute::tuple block_coord, - int &work_idx - ) { - - // some of these are captured in lambda so can't use structured binding - int const m_block = get<0>(block_coord); - int const bidh = get<1>(block_coord); - int const bidb = get<2>(block_coord); - int const split_idx = get<3>(block_coord); - auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( - seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); - // It's possible to have n_block_max <= n_block_min. Loading K can cause illegal memory access. - if constexpr (Is_causal || Is_local || Varlen || Split) { - if (n_block_max <= n_block_min) { - scheduler_prefetch(); - return; - } - } - - Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); - Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); - Tensor sK_pi = as_position_independent_swizzle_tensor(sK); - // as_position_independent_swizzle_tensor makes address calculation easier when we do LDSM & STSM to transpose. - // But it requires smem_vt and smem_v to be aligned to e.g 512 bytes. - Tensor sVt = [&] { - if constexpr (!Transpose_V) { - return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{}); - } else { - return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVt{})); - } - }(); - // Only used if Transpose_V - Tensor sV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{})); - // Only used if we're using cp.async to load V - Tensor sVcpasync = [&] { - if constexpr (!Transpose_V) { - return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVCpAsync{})); - } else { - return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVCpAsync{})); - } - }(); - Tensor sQv = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_qv.data()), SmemLayoutQv{}); - - int const thread_idx = threadIdx.x % NumProducerThreads; - int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; - int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb]; - - // Prepare the TMA loads - uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); - constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); - uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - - bool const is_varlen_q = Varlen && params.cu_seqlens_q; - bool const is_varlen_k = Varlen && params.cu_seqlens_k; - Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor mK_TMA = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, _); - auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K)); - Tensor mVt_TMA = params.tma_load_V.get_tma_tensor(shape_V)(_, _, bidh_kv, _); - - Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) - // if (cute::thread0()) { printf("Varlen = %d, params.leftpad_k = %p, leftpad_k = %d\n", Varlen, params.leftpad_k, leftpad_k); } - Tensor gK_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}, _0{}), mK_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}, _)); // (N, K, _, _) - Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k, _0{}), mVt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _, _)); // (K, N, _, _) - - auto block_tma_Q = params.tma_load_Q.get_slice(_0{}); - Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); // (TMA) - Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); // (TMA) - // tma_partition doesn't handle position_independent_swizzle_tensor correctly, so we need to do it manually - auto block_tma_K = params.tma_load_K.get_slice(cluster_local_block_id.x); - Tensor tKgK_TMA = group_modes<0, 3>(block_tma_K.partition_S(gK_TMA)); // (TMA, k, batch) - Tensor tKsK_TMA = group_modes<0, 3>(block_tma_K.partition_D(sK)); // (TMA, PIPE) - auto block_tma_V = params.tma_load_V.get_slice(cluster_local_block_id.x); - Tensor tVgVt_TMA = group_modes<0, 3>(block_tma_V.partition_S(gVt_TMA)); // (TMA, k, batch) - Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V.partition_D(sVt)); // (TMA, PIPE) - auto [tQvgQv, tQvsQv] = [&] { - if constexpr (HasQv) { - auto shape_Qv = make_shape(get<0>(params.shape_Q), params.headdim_v, get<2>(params.shape_Q), get<3>(params.shape_Q)); - Tensor mQv = params.tma_load_Qv.get_tma_tensor(shape_Qv)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor gQv = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQv), select<0, 2>(TileShape_MNK_QV{}), make_coord(m_block, _0{})); // (M, Kv) - auto block_tma_Qv = params.tma_load_Qv.get_slice(_0{}); - Tensor tQvgQv = group_modes<0, 3>(block_tma_Qv.partition_S(gQv)); // (TMA) - Tensor tQvsQv = group_modes<0, 3>(block_tma_Qv.partition_D(sQv)); // (TMA) - return cute::make_tuple(tQvgQv, tQvsQv); - } else { - return cute::make_tuple(nullptr, nullptr); - } - }(); - - // This is used to index into the batch dimension of mK and mV - int const bidb_kv_idx = !is_varlen_k && !params.ptr_pagetable ? bidb_kv : 0; - - using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumProducerThreads, Element, Transpose_V || !IntraWGOverlap /*KV_Same_Iter*/>; - PagedKVManager_t paged_kv_manager( - params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, - params.ptr_K, params.shape_K, params.stride_K, - params.ptr_V, params.headdim_v, params.stride_V, - params.page_size_divmod, params.blockN_per_page_size_divmod, - bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k, bidb_kv_idx - ); - - // Set up for transposing V, only used if Transpose_V - S2RTiledCopyVt s2r_tiled_copy_vt; - R2STiledCopyV r2s_tiled_copy_v; - auto s2r_thr_copy_vt = s2r_tiled_copy_vt.get_thread_slice(thread_idx); - auto r2s_thr_copy_v = r2s_tiled_copy_v.get_thread_slice(thread_idx); - // flat_divide(sVt, LDSM_divide_shape{}): (64, 8, kHeadDim / 64, kBlockN / 8, kStages) - Tensor tTranssVt_ = s2r_thr_copy_vt.partition_S(flat_divide(sVt, LDSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, kBlockN / 32, kStages) - // flat_divide(sV, STSM_divide_shape{}): (8, 16, kHeadDim / 8, (4, kBlockN / 64), kStages) - Tensor tTranssV_ = r2s_thr_copy_v.partition_D(flat_divide(sV, STSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, (2, kBlockN / 64), kStages) - CUTE_STATIC_ASSERT_V(rank(tTranssVt_) == rank(tTranssV_)); - CUTE_STATIC_ASSERT_V(size<0>(tTranssVt_) == size<0>(tTranssV_)); - CUTE_STATIC_ASSERT_V(size<1>(tTranssVt_) == size<1>(tTranssV_)); - CUTE_STATIC_ASSERT_V(size<2>(tTranssVt_) == size<2>(tTranssV_)); - CUTE_STATIC_ASSERT_V(size<3>(tTranssVt_) == size<3>(tTranssV_)); - CUTE_STATIC_ASSERT_V(size<4>(tTranssVt_) == size<4>(tTranssV_)); - // Faster to have 2 LDSM.T, byte permute, STSM for better ILP - static constexpr int Transpose_ILP = (size<2>(tTranssVt_) * size<3>(tTranssVt_)) % 2 == 0 ? 2 : 1; - Tensor tTranssVt = logical_divide(group_modes<1, rank(tTranssVt_) - 1>(tTranssVt_), Shape>{}); // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2), kStages) - Tensor tTranssV = logical_divide(group_modes<1, rank(tTranssV_) - 1>(tTranssV_), Shape>{}); // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2), kStages) - auto transpose_V = [&](int stage) { - if constexpr (Transpose_V) { - #pragma unroll - for (int i = 0; i < size<1, 1>(tTranssVt); ++i) { - Tensor tTransrV = make_fragment_like(tTranssV(_, make_coord(_, _0{}), _0{})); - static_assert(size<0>(tTransrV) == 16); - Tensor tTransrV_64 = recast(tTransrV); - cute::copy(s2r_tiled_copy_vt, tTranssVt(_, make_coord(_, i), stage), tTransrV); - #pragma unroll - for (int j = 0; j < size(tTransrV_64); ++j) { - uint32_t upper = tTransrV_64[j].x; - uint32_t lower = tTransrV_64[j].y; - tTransrV_64[j].x = __byte_perm(upper, lower, 0x6420); - tTransrV_64[j].y = __byte_perm(upper, lower, 0x7531); - } - cute::copy(r2s_tiled_copy_v, tTransrV, tTranssV(_, make_coord(_, i), stage)); - } - } - }; - - uint16_t mcast_mask_kv = 0; - if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int m = 0; m < size<0>(block_layout); ++m) { - mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{})); - } - } - - auto load_K = [&] (int const n_block, auto const& smem_pipe_write, auto need_seqlenk_masking_type) { - pipeline_k.producer_acquire(smem_pipe_write); - if constexpr (!PagedKVNonTMA) { - auto [n_block_idx, bidb_kv_idx] = paged_kv_manager.get_indices_for_K_TMA(); - copy(params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST), - tKgK_TMA(_, n_block_idx, bidb_kv_idx), tKsK_TMA(_, smem_pipe_write.index())); - } else { - constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; - paged_kv_manager.template load_K(n_block, sK_pi(_, _, smem_pipe_write.index())); - pipeline_k.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); - } - }; - - auto load_V = [&] (int const n_block, auto const& smem_pipe_write, auto need_seqlenk_masking_type) { - auto pipeline_v_load = cute::conditional_return(pipeline_v, pipeline_vt); - pipeline_v_load.producer_acquire(smem_pipe_write); - if constexpr (!PagedKVNonTMA) { - auto [n_block_idx, bidb_kv_idx] = paged_kv_manager.get_indices_for_V_TMA(); - copy(params.tma_load_V.with(*pipeline_v_load.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST), - tVgVt_TMA(_, n_block_idx, bidb_kv_idx), tVsVt_TMA(_, smem_pipe_write.index())); - } else { - constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; - paged_kv_manager.template load_V(n_block, sVcpasync(_, _, smem_pipe_write.index())); - pipeline_v_load.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); - } - }; - - auto copy_Vt_to_V = [&] (auto const& smem_pipe_write) { - // Instead of maintaining smem_pipe_read as a separate variable, we can just use smem_pipe_write, - // and exploit the invariance that smem_pipe_write.phase() == smem_pipe_read.phase() ^ 1. - // This saves 1 or 2 registers. - PipelineState smem_pipe_read{smem_pipe_write.index(), smem_pipe_write.phase() ^ 1, smem_pipe_write.count()}; - pipeline_vt.consumer_wait(smem_pipe_read); - pipeline_v.producer_acquire(smem_pipe_write); - transpose_V(smem_pipe_write.index()); - // SMEM fence to make sure V is transposed before math - cutlass::arch::fence_view_async_shared(); - pipeline_v.producer_commit(smem_pipe_write); - // Very important: PipelineTmaAsync::consumer_release assumes that the warpgroup is synchronized - // before calling. Without this we get race conditions. - cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, cutlass::arch::ReservedNamedBarriers::TransposeBarrier /*id*/); - pipeline_vt.consumer_release(smem_pipe_read); - }; - - int n_block = n_block_max - 1; - - int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); - // If this is true, we're guaranteed that only the first warp will execute this function - static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp; - bool should_load_KV = !Use_TMA_KV || ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()); - - if (should_load_KV) { - if constexpr (PagedKVNonTMA) { - paged_kv_manager.template load_page_table(n_block); - } else { - paged_kv_manager.template load_page_table_TMA(n_block); - } - if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); } - // if (thread_idx == 0) { printf("Producer: main load, before load_K, index = %d\n", smem_pipe_write.index());} - load_K(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); - // if (thread_idx == 0) { printf("Producer: main load, after load K, index = %d\n", smem_pipe_write.index());} - } - - if constexpr (Use_TMA_Q) { - // Wait for the MMA warpgroups to signal that smem_q is ready - if (SingleProducerWarp || warp_idx_in_warpgroup == 0) { - cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); - } - - if ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()) { - shared_storage.pipelines.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); - copy(params.tma_load_Q.with(reinterpret_cast(shared_storage.pipelines.barrier_Q), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST), - tQgQ, tQsQ); - if constexpr (HasQv) { - shared_storage.pipelines.barrier_Qv.arrive_and_expect_tx(TmaTransactionBytesQv); - copy(params.tma_load_Qv.with(reinterpret_cast(shared_storage.pipelines.barrier_Qv), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST), - tQvgQv, tQvsQv); - } - } - } else { // Load Q with cp.async - cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); - Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ); - using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element>; - PackGQAt::load_Q(mQ, sQ_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block); - auto &barrier_Q = shared_storage.pipelines.barrier_Q; - cutlass::arch::cpasync_barrier_arrive(reinterpret_cast(&barrier_Q)); - barrier_Q.arrive(); - if constexpr (HasQv) { - Tensor mQv = make_tensor(make_gmem_ptr(params.ptr_Qv + seqlen_info.offset_q * get<0>(params.stride_Qv)), params.shape_Qv_packed, params.stride_Qv_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor sQv_pi = cute::as_position_independent_swizzle_tensor(sQv); - using PackGQAt = flash::PackGQAManager(TileShape_MNK_QV{}), get<2>(TileShape_MNK_QV{}), NumProducerThreads, Element>; - PackGQAt::load_Q(mQv, sQv_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block); - auto &barrier_Qv = shared_storage.pipelines.barrier_Qv; - cutlass::arch::cpasync_barrier_arrive(reinterpret_cast(&barrier_Qv)); - barrier_Qv.arrive(); - } - } - - // Wait for the MMA WGs to signal that smem_v are ready and V can be copied from gmem - // Need ClusterBarrier, not just NamedBarrier. Otherwise we might have CTA 0 finishing the - // TMA store on O first, call TMA multicast load on V, before CTA 1 can finishing TMA store on O. - // if (thread_idx == 0) { printf("Producer: main load, before barrier_O, work_idx = %d\n", work_idx);} - shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2); - // if (thread_idx == 0) { printf("Producer: main load, after barrier_O\n");} - - if constexpr (!Transpose_V && !IntraWGOverlap) { - if (should_load_KV) { load_V(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); } - } - int n_block_prev = n_block; - --n_block; - #pragma unroll (!Transpose_V && Use_TMA_KV ? 2 : 1) - for (; n_block >= n_block_min; --n_block) { - PipelineState smem_pipe_write_v = smem_pipe_write; // copy the state, write_v is always 1 step behind - ++smem_pipe_write; - if (should_load_KV) { - if constexpr (PagedKVNonTMA) { - paged_kv_manager.template load_page_table(n_block); - } else { - paged_kv_manager.load_page_table_TMA(n_block); - } - if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); } - load_K(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); - if constexpr (!Transpose_V) { - if constexpr (IntraWGOverlap) { - load_V(n_block_prev, smem_pipe_write_v, cute::true_type{} /*Seqlenk_mask*/); - } else { - load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); - } - } - } - n_block_prev = n_block; - if constexpr (Transpose_V) { copy_Vt_to_V(smem_pipe_write_v); } - } - scheduler_prefetch(); - if constexpr (!Transpose_V && IntraWGOverlap) { - if (should_load_KV) { load_V(n_block_prev, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); } - } - if constexpr (Transpose_V) { copy_Vt_to_V(smem_pipe_write); } - ++smem_pipe_write; - // At the end, all threads have the correct smem_pipe_write. - ++work_idx; - } - - template - CUTLASS_DEVICE void - load_tail(MainloopPipelineK pipeline_k, MainloopPipelineV pipeline_v, MainloopPipelineVt pipeline_vt, - PipelineState& smem_pipe_write, SharedStorage &shared_storage, int const work_idx) { - // If we don't wait for barrier_O here, when using Cluster, CTA0 might exit early and CTA1 will - // try to arrive on barrier_O of CTA0, causing "unspecified launch failure". - shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2); - int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); - // Issue the epilogue waits - // TODO: check if this should be called by 1 thread or more - if (warp_idx_in_warpgroup == 0 && cute::elect_one_sync()) { - /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was still inverted from make_producer_start_state - */ - pipeline_k.producer_tail(smem_pipe_write); - pipeline_v.producer_tail(smem_pipe_write); - if constexpr (Transpose_V) { pipeline_vt.producer_tail(smem_pipe_write); } - } - } - - CUTLASS_DEVICE void - warp_scheduler_barrier_sync() { - if constexpr (UseSchedulerBarrier) { - cutlass::arch::NamedBarrier::sync(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/); - } - } - - CUTLASS_DEVICE void - warp_scheduler_barrier_arrive() { - if constexpr (UseSchedulerBarrier) { - static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); - int const cur_WG = flash::canonical_warp_group_idx_nosync() - 1; - int const next_WG = NumMmaWarpGroups == 2 - ? 1 - cur_WG - : (cur_WG < NumMmaWarpGroups - 1 ? cur_WG + 1 : 0); - cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) + next_WG /*id*/); - } - } - - CUTLASS_DEVICE void - mma_init() { - int warp_group_idx = flash::canonical_warp_group_idx_nosync(); - // Tell producers that smem_q is ready - if (!LargeHeadDimV || warp_group_idx == 1) { - cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); - } - if (LargeHeadDimV && warp_group_idx > 1) { - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); - } - if constexpr (UseSchedulerBarrier) { - // We have NamedBarrier for up to 3 WGs - static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); - // WG1 needs the very first signal to start - if (warp_group_idx == 1) { - cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) /*id*/); - } - } - } - - template - CUTLASS_DEVICE bool - mma(Params const& params, - MainloopPipelineK pipeline_k, - MainloopPipelineV pipeline_v, - PipelineState& smem_pipe_read, - FrgTensorO& tOrO, - Softmax& softmax, - int const thread_idx, - int &work_idx, - SeqlenInfo_t const& seqlen_info, - cute::tuple block_coord, - SharedStorage& shared_storage - ) { - static_assert(is_rmem::value, "O tensor must be rmem resident."); - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - - // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda - int const m_block = get<0>(block_coord); - int const bidh = get<1>(block_coord); - int const bidb = get<2>(block_coord); - int const split_idx = get<3>(block_coord); - int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; - auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( - seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); - // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier - if constexpr (Is_causal || Is_local || Varlen || Split) { - if (n_block_max <= n_block_min) { return false; } - } - - Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); - Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); - Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{}); - Tensor sP = [&] { - if constexpr (MmaPV_is_RS) { - // We might not have smem_p if !MmaPV_is_RS, just use smem_q as a placeholder since we don't use it - return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutP{}); - } else { - return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{}); - } - }(); - Tensor sScale = [&] { - if constexpr (LargeHeadDimV) { - return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_scale.data()), SmemLayoutScale{}); - } else { // won't be used, just a placeholder - return make_tensor(make_smem_ptr(static_cast(nullptr)), SmemLayoutScale{}); - } - }(); - Tensor sQv = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_qv.data()), SmemLayoutQv{}); - Tensor sVMmaQV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVMmaQV{}); - - if constexpr (!MmaQK_is_RS) { - static_assert(stride<0>(typename TiledMmaQK::ALayout{}) == 0 and - stride<0>(typename TiledMmaQK::BLayout{}) == 0 and - size<0>(typename TiledMmaQK::ALayout{}) == cutlass::NumThreadsPerWarpGroup and - size<0>(typename TiledMmaQK::BLayout{}) == cutlass::NumThreadsPerWarpGroup, - "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); - } - static constexpr int MmaWarpGroups = size(TiledMmaPV{}) / cutlass::NumThreadsPerWarpGroup; - Layout warp_group_thread_layout = make_layout(make_shape(Int{}), - make_stride(Int{})); - - int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); - TiledMmaQK tiled_mma_qk; - TiledMmaPV tiled_mma_pv; - TiledMmaQV tiled_mma_qv; - auto wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx)); - auto wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)); - auto wg_mma_qv = tiled_mma_qv.get_slice(warp_group_thread_layout(warp_group_idx)); - - auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtomP{}, tiled_mma_qk); - auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx); - - // Allocate "fragments/descriptors" - Tensor tSrQ = wg_mma_qk.partition_fragment_A(sQ); - Tensor tSrK = wg_mma_qk.partition_fragment_B(sK); - Tensor tOrV = wg_mma_pv.partition_fragment_B(sV); - Tensor tOsP = wg_mma_pv.partition_fragment_A(sP); - Tensor tSrQv = wg_mma_qv.partition_fragment_A(sQv); - Tensor tSrV = wg_mma_qv.partition_fragment_B(sVMmaQV); - Tensor tPsP = smem_thr_copy_P.partition_D(cute::as_position_independent_swizzle_tensor(sP)); - - // For storing scales to smem, only used when LargeHeadDimV - auto thread_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx); - Tensor taccOcO = thread_mma_pv.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); - Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); - Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); - auto store_scales = [&](auto& scales, int stage) { - static_assert(CUTE_STATIC_V(size(scales)) == CUTE_STATIC_V(size(taccOcO_row))); - #pragma unroll - for (int mi = 0; mi < size(taccOcO_row); ++mi) { - if (get<1>(taccOcO_row(_0{})) == 0) { - sScale(get<0>(taccOcO_row(mi)), stage) = scales(mi); - } - } - }; - - auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { - auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - }; - - int const seqlen_q = seqlen_info.seqlen_q; - int const seqlen_k = seqlen_info.seqlen_k; - int n_block = n_block_max - 1; - - flash::Mask mask( - thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, - params.qhead_per_khead_divmod - ); - - float softcap_val = params.softcap_val; - if constexpr (Has_softcap && Is_FP8) { - float const q_descale = params.ptr_q_descale == nullptr ? 1.0f : params.ptr_q_descale[bidb * get<0>(params.stride_q_descale) + bidh_kv * get<1>(params.stride_q_descale)]; - float const k_descale = params.ptr_k_descale == nullptr ? 1.0f : params.ptr_k_descale[bidb * get<0>(params.stride_k_descale) + bidh_kv * get<1>(params.stride_k_descale)]; - softcap_val *= q_descale * k_descale; - } - // Softcapping needs to happen before masking since if we apply after masking, softcapping - // can turn -inf to e.g. -50.0, which can affect the attention softmax. - auto scoremod_premask_fn = [&](auto& tSrS) { - if constexpr (Has_softcap) { flash::apply_softcap(tSrS, softcap_val); } - }; - - auto write_P_to_smem = [&](auto& tOrP) { - if constexpr (LargeHeadDimV) { - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); - } - cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); - }; - - auto arrive_on_P_write_barrier = [&] { - cutlass::arch::fence_view_async_shared(); - __syncwarp(); // Only need syncwarp since each warp is using its own P values for MmaPV - if constexpr (LargeHeadDimV) { - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); - } - }; - - using TensorT = typename Softmax::TensorT; - using LayoutT = typename TensorT::layout_type; - auto finalize_dispatch = [&](TensorT& scores_scale, float const v_descale) { - if (params.ptr_S_aux && (!Split || (split_idx & 0x0000FFFF) == 0)) { - Tensor sS_aux = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_s_aux.data()), SmemLayoutSAux{}); - Tensor tSrS_aux = make_tensor_like(scores_scale); - static_assert(is_static::value); - static_assert(size(tSrS_aux) == size(LayoutT{})); - if constexpr(!PackGQA) { - #pragma unroll - for(int mi = 0; mi < size(tSrS_aux); ++mi) { - tSrS_aux(mi) = static_cast(sS_aux(bidh)); - } - } else { - Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); - auto thread_mma_qk = tiled_mma_qk.get_thread_slice(thread_idx); - Tensor tScS = thread_mma_qk.partition_C(cS); - Tensor tScS_rowcol = make_tensor(tScS.data(), flash::convert_layout_acc_rowcol(tScS.layout())); - static_assert(size<0>(tScS_rowcol) == size(tSrS_aux)); - int const qhead_per_khead = params.qhead_per_khead_divmod.divisor; - #pragma unroll - for(int mi = 0; mi < size(tSrS_aux); ++mi) { - int row = m_block * kBlockM + get<0>(tScS_rowcol(mi, _0{})); - int bidh_mi = (row % qhead_per_khead) + bidh_kv * qhead_per_khead; - tSrS_aux(mi) = static_cast(sS_aux(bidh_mi)); - } - } - cute::copy(softmax.finalize_aux(tSrS_aux, v_descale), scores_scale); - } else { - cute::copy(softmax.finalize(v_descale), scores_scale); - } - }; - - auto &barrier_Q = shared_storage.pipelines.barrier_Q; - if constexpr (!AppendKV) { - barrier_Q.wait(work_idx % 2); - } else { - if (get<1>(params.shape_rotary) > 0) { // Apply rotary to Q - using Rotary_t = Rotary; - Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, - params.ptr_rotary_sin, params.stride_rotary_sin, - params.is_rotary_interleaved, thread_idx, seqlen_q, - seqlen_info.seqlen_rotary); - Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ); - int const qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor; - if (params.is_rotary_interleaved) { - auto [tRrCos, tRrSin] = cute::conditional_return( - rotary.template load_cos_sin(m_block), - rotary.template load_cos_sin_packgqa(m_block, params.qhead_per_khead_divmod) - ); - barrier_Q.wait(work_idx % 2); - rotary.apply_Q_interleaved(sQ_pi, tRrCos, tRrSin, m_block, qhead_per_khead); - } else { - auto [tRrCosCont, tRrSinCont] = cute::conditional_return( - rotary.template load_cos_sin(m_block), - rotary.template load_cos_sin_packgqa(m_block, params.qhead_per_khead_divmod) - ); - barrier_Q.wait(work_idx % 2); - rotary.apply_Q_contiguous(sQ_pi, tRrCosCont, tRrSinCont, m_block, qhead_per_khead); - } - // SMEM fence to make sure the rotated Q is visible to GMMA - cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK, static_cast(FwdNamedBarriers::QueryRotated) /*id*/); - } else { - barrier_Q.wait(work_idx % 2); - } - } - - if constexpr (MmaQK_is_RS) { - using SmemCopyAtomQ = Copy_Atom; - auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtomQ{}, tiled_mma_qk); - auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(thread_idx); - Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); - Tensor tSsQ_copy_view = smem_thr_copy_Q.partition_S(cute::as_position_independent_swizzle_tensor(sQ)); - cute::copy(smem_tiled_copy_Q, tSsQ_copy_view, tSrQ_copy_view); - } - - if constexpr (IntraWGOverlap) { - Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); - consumer_wait(pipeline_k, smem_pipe_read); - flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); - warpgroup_wait<0>(); - pipeline_k.consumer_release(smem_pipe_read); - if constexpr (HasQv) { - shared_storage.pipelines.barrier_Qv.wait(work_idx % 2); - consumer_wait(pipeline_v, smem_pipe_read); - flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); - } - scoremod_premask_fn(tSrS); - mask.template apply(tSrS, m_block, n_block); - - Tensor scores_scale = softmax.template max_get_scale(tSrS); - // Don't need to store scales to send to WG1 (in the case of LargeHeadDimV) since it's 1.f - - softmax.template online_softmax(tSrS); - if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } - Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); - Tensor tOrP = make_tensor_like(tOrP_acc); - convert_type_out(tOrP_acc, tOrP); - if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } - if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } - if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); } - --n_block; - - // Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter - clear(tOrO); - // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; - - // Each step does gemm0 for iter n_block, gemm1 for iter n_block + 1, and softmax for iter n_block. - auto fwd_step = [&](int const n_block, auto mask_fn, auto check_inf_type) { - static constexpr bool Check_inf = decltype(check_inf_type)::value; - PipelineState smem_pipe_read_v(smem_pipe_read.index(), smem_pipe_read.phase(), smem_pipe_read.count()); - ++smem_pipe_read; - Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); - if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_k, smem_pipe_read); } - warp_scheduler_barrier_sync(); - flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); - if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } - if constexpr(!HasQv) { - if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); } - } - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); - warp_scheduler_barrier_arrive(); - warpgroup_wait<1>(); - pipeline_k.consumer_release(smem_pipe_read); // release K - if constexpr (HasQv) { - warpgroup_wait<0>(); - pipeline_v.consumer_release(smem_pipe_read_v); // release V - consumer_wait(pipeline_v, smem_pipe_read); - flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); - } - scoremod_premask_fn(tSrS); - mask_fn(tSrS, n_block); - cute::copy(softmax.template max_get_scale(tSrS), scores_scale); - if constexpr (LargeHeadDimV) { store_scales(scores_scale, smem_pipe_read_v.index()); } - softmax.template online_softmax(tSrS); - if constexpr (!HasQv) { - warpgroup_wait<0>(); - pipeline_v.consumer_release(smem_pipe_read_v); // release V - } - if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } - convert_type_out(make_tensor(tSrS.data(), tOrP.layout()), tOrP); - if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } - if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } - if constexpr (!RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } - if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); } - }; - - if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking - auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; - int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM); - int const n_block_min_causal_local_mask = - std::max(n_block_min, (m_idx_min + seqlen_k - seqlen_q + params.window_size_right) / kBlockN); - #pragma unroll 1 - for (; n_block >= n_block_min_causal_local_mask; --n_block) { - fwd_step(n_block, mask_fn, cute::true_type{} /*check_inf*/); - } - } - - int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1; - int const n_block_min_before_local_mask = !Is_local - ? n_block_min - : std::max(n_block_min, - cute::ceil_div(m_idx_max + seqlen_k - seqlen_q - params.window_size_left, kBlockN)); - auto no_mask_fn = [](auto& tSrS, int n_block) { }; - #pragma unroll 1 - for (; n_block >= n_block_min_before_local_mask; --n_block) { - fwd_step(n_block, no_mask_fn, cute::false_type{} /*check_inf*/); - } - // Separate masking iterations on the left for local attention - if constexpr (Is_local) { - auto local_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; - #pragma unroll 1 - for (; n_block >= n_block_min; --n_block) { - fwd_step(n_block, local_mask_fn, cute::bool_constant{} /*check_inf*/); - } - } - // Tell producers that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); - if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } - if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); - float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; - // cute::copy(softmax.finalize(v_descale), scores_scale); - finalize_dispatch(scores_scale, v_descale); - if constexpr (LargeHeadDimV) { - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); - store_scales(scores_scale, smem_pipe_read.index()); - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); - } - warpgroup_wait<0>(); - pipeline_v.consumer_release(smem_pipe_read); // release V, otherwise producers will hang - softmax.rescale_o(tOrO, scores_scale); - if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } - ++smem_pipe_read; - - } else { // No intra-WG overlap - - warp_scheduler_barrier_sync(); - - auto fwd_step = [&](int const n_block, auto mask_fn, auto is_first_iter_type, auto check_inf_type) { - static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value; - static constexpr bool Check_inf = decltype(check_inf_type)::value; - auto smem_pipe_read_prev = smem_pipe_read; - if constexpr (!Is_first_iter) { ++smem_pipe_read; } - Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); - consumer_wait(pipeline_k, smem_pipe_read); - flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); - if constexpr (!HasQv) { - warp_scheduler_barrier_arrive(); - warpgroup_wait<0>(); - pipeline_k.consumer_release(smem_pipe_read); // release K - } else { - if constexpr (Is_first_iter) { - shared_storage.pipelines.barrier_Qv.wait(work_idx % 2); - } - consumer_wait(pipeline_v, smem_pipe_read); - flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); - warp_scheduler_barrier_arrive(); - warpgroup_wait<1>(); - pipeline_k.consumer_release(smem_pipe_read); // release K - warpgroup_wait<0>(); - } - scoremod_premask_fn(tSrS); - mask_fn(tSrS, n_block); - Tensor scores_scale = softmax.template max_get_scale(tSrS); - if constexpr (LargeHeadDimV && !Is_first_iter) { store_scales(scores_scale, smem_pipe_read_prev.index()); } - softmax.template online_softmax(tSrS); - if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } - Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); - Tensor tOrP = make_tensor_like(tOrP_acc); - convert_type_out(tOrP_acc, tOrP); - if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } - if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } - if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); } - if constexpr (!MmaPV_is_RS && !MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); } - if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } - warp_scheduler_barrier_sync(); - if constexpr (!MmaPV_use_RS_WG1) { - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); - } else { - TiledMmaPV_RS tiled_mma_pv_rs; - flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); - } - if constexpr (!MmaPV_is_RS && MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); } - warpgroup_wait<0>(); - pipeline_v.consumer_release(smem_pipe_read); // release V - }; - - auto first_iter_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; - fwd_step(n_block, first_iter_mask_fn, cute::true_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/); - --n_block; - if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking - auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; - int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM); - int const n_block_min_causal_local_mask = - std::max(n_block_min, (m_idx_min + seqlen_k - seqlen_q + params.window_size_right) / kBlockN); - #pragma unroll 1 - for (; n_block >= n_block_min_causal_local_mask; --n_block) { - fwd_step(n_block, mask_fn, cute::false_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/); - } - } - int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1; - int const n_block_min_before_local_mask = !Is_local - ? n_block_min - : std::max(n_block_min, - cute::ceil_div(m_idx_max + seqlen_k - seqlen_q - params.window_size_left, kBlockN)); - auto no_mask_fn = [](auto& tSrS, int n_block) { }; - #pragma unroll 1 - for (; n_block >= n_block_min_before_local_mask; --n_block) { - fwd_step(n_block, no_mask_fn, cute::false_type{} /*is_first_iter*/, cute::false_type{} /*check_inf*/); - } - // Separate masking iterations on the left for local attention - if constexpr (Is_local) { - auto local_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; - #pragma unroll 1 - for (; n_block >= n_block_min; --n_block) { - fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant{} /*check_inf*/); - } - } - warp_scheduler_barrier_arrive(); - // Tell producers that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); - float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; - // Tensor scores_scale = softmax.finalize(v_descale); - Tensor scores_scale = make_tensor_like(softmax.row_max); - finalize_dispatch(scores_scale, v_descale); - - if constexpr (LargeHeadDimV) { - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); - store_scales(scores_scale, smem_pipe_read.index()); - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); - } - softmax.rescale_o(tOrO, scores_scale); - if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } - ++smem_pipe_read; - } - ++work_idx; - return true; - } - - template - CUTLASS_DEVICE bool - mma_pv(Params const& params, - MainloopPipelineV pipeline_v, - PipelineState& smem_pipe_read, - FrgTensorO& tOrO, - Softmax& softmax, - int const thread_idx, - SeqlenInfo_t const& seqlen_info, - cute::tuple block_coord, - SharedStorage& shared_storage - ) { - static_assert(is_rmem::value, "O tensor must be rmem resident."); - // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda - int const m_block = get<0>(block_coord); - int const bidb = get<2>(block_coord); - int const split_idx = get<3>(block_coord); - auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( - seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); - // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier - if constexpr (Is_causal || Is_local || Varlen || Split) { - if (n_block_max <= n_block_min) { return false; } - } - - Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{}); - Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{}); - Tensor sScale = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_scale.data()), SmemLayoutScale{}); - static constexpr int MmaWarpGroups = size(TiledMmaPV{}) / cutlass::NumThreadsPerWarpGroup; - Layout warp_group_thread_layout = make_layout(make_shape(Int{}), - make_stride(Int{})); - - int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); - TiledMmaPV tiled_mma_pv; - auto wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)); - - // Allocate "fragments/descriptors" - Tensor tOrV = wg_mma_pv.partition_fragment_B(sV); - Tensor tOsP = wg_mma_pv.partition_fragment_A(sP); - - // For load scales to smem, pretend thread_idx is thread_idx % 128 - auto thread_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx % cutlass::NumThreadsPerWarpGroup); - Tensor taccOcO = thread_mma_pv.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); - Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); - Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); - auto load_scales = [&](auto& scales, int stage) { - static_assert(CUTE_STATIC_V(size(scales)) == CUTE_STATIC_V(size(taccOcO_row))); - #pragma unroll - for (int mi = 0; mi < size(taccOcO_row); ++mi) { - scales(mi) = sScale(get<0>(taccOcO_row(mi)), stage); - } - }; - - // clear(tOrO); - // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; - - typename Softmax::TensorT scores_scale; - - int n_block = n_block_max - 1; - // If HasQv, then by the time P is ready, V must have been ready as well - if constexpr (!HasQv) { pipeline_v.consumer_wait(smem_pipe_read); } - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); - flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); - pipeline_v.consumer_release(smem_pipe_read); // release V - --n_block; - - #pragma unroll 1 - for (; n_block >= n_block_min; --n_block) { - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); - load_scales(scores_scale, smem_pipe_read.index()); - softmax.rescale_o(tOrO, scores_scale); - ++smem_pipe_read; - if constexpr (!HasQv) { - auto barrier_token = pipeline_v.consumer_try_wait(smem_pipe_read); - pipeline_v.consumer_wait(smem_pipe_read, barrier_token); - } - flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); - pipeline_v.consumer_release(smem_pipe_read); // release V - }; - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); - load_scales(scores_scale, smem_pipe_read.index()); - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); - softmax.rescale_o(tOrO, scores_scale); - if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } - ++smem_pipe_read; - return true; - } - - template - CUTLASS_DEVICE bool - load_kv_new(Params const& params, - MainloopPipelineKVNew pipeline_k_new, - MainloopPipelineKVNew pipeline_v_new, - PipelineState& smem_pipe_write, - SharedStorage &shared_storage, - SeqlenInfo_t const& seqlen_info, - cute::tuple block_coord, - int const work_idx - ) { - - auto [m_block, bidh, bidb, split_idx] = block_coord; - auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max( - seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); - - if (n_block_new_max <= n_block_new_min) { return false; } - - Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); - Tensor sVt = [&] { - if constexpr (!Transpose_V) { - return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{}); - } else { - return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVt{}); - } - }(); - - // int const thread_idx = threadIdx.x % NumProducerThreads; - int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; - - // Prepare the TMA loads - uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); - constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); - uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - - bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new; - Tensor mKnew_TMA = params.tma_load_K_new.get_tma_tensor(params.shape_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); - auto shape_Vnew = make_shape(params.headdim_v, get<0>(params.shape_K_new), get<2>(params.shape_K_new), get<3>(params.shape_K_new)); - Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(shape_Vnew)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); - - Tensor gKnew_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _)); // (K, N, _) - - auto block_tma_K_new = params.tma_load_K_new.get_slice(cluster_local_block_id.x); - Tensor tKgKnew_TMA = group_modes<0, 3>(block_tma_K_new.partition_S(gKnew_TMA)); // (TMA, k) - Tensor tKsK_TMA = group_modes<0, 3>(block_tma_K_new.partition_D(sK)); // (TMA, PIPE) - auto block_tma_V_new = params.tma_load_V_new.get_slice(cluster_local_block_id.x); - Tensor tVgVnewt_TMA = group_modes<0, 3>(block_tma_V_new.partition_S(gVnewt_TMA)); // (TMA, k) - Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V_new.partition_D(sVt)); // (TMA, PIPE) - - uint16_t mcast_mask_kv = 0; - if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int m = 0; m < size<0>(block_layout); ++m) { - mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{})); - } - } - - auto load_K_new = [&] (int const n_block, auto const& smem_pipe_write) { - pipeline_k_new.producer_acquire(smem_pipe_write); - copy(params.tma_load_K_new.with(*pipeline_k_new.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_FIRST), - tKgKnew_TMA(_, n_block), tKsK_TMA(_, smem_pipe_write.index())); - }; - - auto load_V_new = [&] (int const n_block, auto const& smem_pipe_write) { - pipeline_v_new.producer_acquire(smem_pipe_write); - copy(params.tma_load_V_new.with(*pipeline_v_new.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_FIRST), - tVgVnewt_TMA(_, n_block), tVsVt_TMA(_, smem_pipe_write.index())); - }; - - int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); - // If this is true, we're guaranteed that only the first warp will execute this function - static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp; - bool should_load_KV = (SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync(); - - int n_block = n_block_new_max - 1; - // Need to wait for barrier_O even before load_K_new since the pipelines for AppendKV - // and the main attention are not the same. We want to make sure the consumers - // have finished reading all smem_k and smem_v for the previous iteration. - shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2); - if (should_load_KV) { load_K_new(n_block, smem_pipe_write); } - // if (thread_idx == 0) { printf("Producer: Done loading K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } - if (should_load_KV) { load_V_new(n_block, smem_pipe_write); } - // if (thread_idx == 0) { printf("Producer: Done loading V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } - ++smem_pipe_write; - --n_block; - // if (thread_idx == 0) { printf("Producer: before for loop\n"); } - #pragma unroll 1 - for (; n_block >= n_block_new_min; --n_block) { - if (should_load_KV) { - load_K_new(n_block, smem_pipe_write); - // if (thread_idx == 0) { printf("Producer: Done loading K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } - load_V_new(n_block, smem_pipe_write); - // if (thread_idx == 0) { printf("Producer: Done loading V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } - } - ++smem_pipe_write; - } - // if (thread_idx == 0) { printf("Producer: after for loop\n"); } - // At the end, all threads have the correct smem_pipe_write. - return true; - } - - template - CUTLASS_DEVICE bool - store_kv_new(Params const& params, - MainloopPipelineKVNew pipeline_k_new, - MainloopPipelineKVNew pipeline_v_new, - PipelineState& smem_pipe_read, - int const thread_idx, - SharedStorage &shared_storage, - SeqlenInfo_t const& seqlen_info, - cute::tuple block_coord - ) { - auto [m_block, bidh, bidb, split_idx] = block_coord; - auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max( - seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); - if (n_block_new_max <= n_block_new_min) { return false; } - - // as_position_independent_swizzle_tensor makes address calculation easier - Tensor sK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{})); - // We want to use SmemLayoutVCpAsync to have shape (kBlockN, kHeadDim) instead of (kHeadDim, kBlockN) - Tensor sV = [&] { - if constexpr (!Transpose_V) { - return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVCpAsync{})); - } else { - return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVCpAsync{})); - } - }(); - - int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; - int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb]; - - bool const is_varlen_k = Varlen && params.cu_seqlens_k; - Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); - auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K)); - Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), shape_V, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); - - int const offset_k = seqlen_info.offset_k + seqlen_info.seqlen_k_og; - Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<2, 1>(TileShape_MNK_PV{}), make_coord(_, _0{})); // (N, K_v, _) - - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - static constexpr int kHeadDim = get<2>(TileShape_MNK{}); - int const seqlen_k_new = seqlen_info.seqlen_k_new; - using Rotary_t = Rotary; - Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, - params.ptr_rotary_sin, params.stride_rotary_sin, - params.is_rotary_interleaved, thread_idx, seqlen_k_new, - seqlen_info.seqlen_rotary); - - // This is used to index into the batch dimension of mK and mV - int const bidb_kv_idx = !is_varlen_k && !params.ptr_pagetable ? bidb_kv : 0; - - using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; - PagedKVManager_t paged_kv_manager( - params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, - params.ptr_K, params.shape_K, params.stride_K, - params.ptr_V, params.headdim_v, params.stride_V, - params.page_size_divmod, params.blockN_per_page_size_divmod, - bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k, bidb_kv_idx - // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position - ); - - if constexpr (UseSchedulerBarrier) { - // WG1 already got the very first signal from mma_init(), but we'll be using the same NamedBarrier. - // So we'll need to "cancel it out" here and then re-signal it at the end. - if (flash::canonical_warp_group_idx_nosync() == 1) { - cutlass::arch::NamedBarrier::sync(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) /*id*/); - } - } - - static_assert(std::is_same_v); - static_assert(!PagedKVNonTMA || std::is_same_v); - GmemTiledCopyAppendKV gmem_tiled_copy_kv; - auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_thread_slice(thread_idx); - Tensor tKsK = gmem_thr_copy_kv.partition_S(sK); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tKgK = gmem_thr_copy_kv.partition_D(gK); - Tensor tVsV = gmem_thr_copy_kv.partition_S(sV); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tVgV = gmem_thr_copy_kv.partition_D(gV); - Tensor cK = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k) - Tensor tKcK = gmem_thr_copy_kv.partition_D(cK); - Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); - #pragma unroll - for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKcK(_0{}, _0{}, k)) < get<1>(params.shape_K); } - Tensor cV = cute::make_identity_tensor(select<2, 1>(TileShape_MNK_PV{})); // (BLK_N,BLK_K_V) -> (blk_n,blk_k_v) - Tensor tVcV = cute::conditional_return(tKcK, gmem_thr_copy_kv.partition_D(cV)); - Tensor tVpV_ = make_tensor(make_shape(size<2>(tVsV))); - #pragma unroll - for (int k = 0; k < size(tVpV_); ++k) { tVpV_(k) = get<1>(tVcV(_0{}, _0{}, k)) < params.headdim_v; } - Tensor tVpV = cute::conditional_return(tKpK, tVpV_); - - auto store_K = [&] (int const n_block, auto const& smem_pipe_read) { - int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN); - if (get<1>(params.shape_rotary) <= 0) { - pipeline_k_new.consumer_wait(smem_pipe_read); - Tensor tKsK_cur = tKsK(_, _, _, smem_pipe_read.index()); - if constexpr (!PagedKVNonTMA) { - Tensor tKgK_cur = tKgK(_, _, _, n_block); - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_kv, tKsK_cur, tKgK_cur, tKcK, tKpK, std::min(seqlen_k_new - n_block * kBlockN, kBlockN) - ); - } else { - paged_kv_manager.store_K(n_block, tKsK_cur); - } - } else { - Tensor gK_cur = gK(_, _, n_block); - auto tPrKPtr = cute::conditional_return(paged_kv_manager.compute_K_ptr(), nullptr); - if (params.is_rotary_interleaved) { - auto [tRrCos, tRrSin] = rotary.template load_cos_sin(n_block); - pipeline_k_new.consumer_wait(smem_pipe_read); - rotary.template apply_K_interleaved(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCos, tRrSin, tPrKPtr, n_block); - } else { - auto [tRrCosCont, tRrSinCont] = rotary.template load_cos_sin(n_block); - pipeline_k_new.consumer_wait(smem_pipe_read); - rotary.template apply_K_contiguous(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCosCont, tRrSinCont, tPrKPtr, n_block, get<1>(params.shape_K)); - } - } - // Without this fence I'm getting race condition when seqlen_k is large - cutlass::arch::fence_view_async_shared(); - // Very important: PipelineTmaAsync::consumer_release assumes that the warpgroup is synchronized - // before calling. - cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/); - pipeline_k_new.consumer_release(smem_pipe_read); - // if (thread_idx == 0) { print_tensor(tKpK); printf("\n"); printf("seqlen_limit = %d\n", seqlen_k_new - n_block * kBlockN);} - }; - - auto store_V = [&] (int const n_block, auto const& smem_pipe_read) { - pipeline_v_new.consumer_wait(smem_pipe_read); - int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN); - Tensor tVsV_cur = tVsV(_, _, _, smem_pipe_read.index()); - if constexpr (!PagedKVNonTMA) { - Tensor tVgV_cur = tVgV(_, _, _, n_block); - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_kv, tVsV_cur, tVgV_cur, tVcV, tVpV, n_limit); - } else { - paged_kv_manager.store_V(n_block, tVsV_cur); - } - cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/); - pipeline_v_new.consumer_release(smem_pipe_read); - }; - - #pragma unroll 1 - for (int n_block = n_block_new_max - 1; n_block >= n_block_new_min; --n_block) { - if constexpr (PagedKVNonTMA) { paged_kv_manager.template load_page_table(n_block); } - store_K(n_block, smem_pipe_read); - // if (thread_idx == 0) { printf("Done storing K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } - store_V(n_block, smem_pipe_read); - // if (thread_idx == 0) { printf("Done storing V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } - ++smem_pipe_read; - } - // if (thread_idx == 0) { printf("After for loop\n"); } - - // Re-signaling the NamedBarrier that we "canceled out" - if constexpr (UseSchedulerBarrier) { - if (flash::canonical_warp_group_idx_nosync() == 1) { - cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) /*id*/); - } - } - - return true; - - } - -}; - -} // namespace flash diff --git a/flash-attn/mask.h b/flash-attn/mask.h deleted file mode 100644 index 02d046268cf8a23f3ad36e87f03b2d3c7fdf4099..0000000000000000000000000000000000000000 --- a/flash-attn/mask.h +++ /dev/null @@ -1,157 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -#include "cutlass/fast_math.h" // For cutlass::FastDivmod - -#include "utils.h" - -namespace flash { - -using namespace cute; - -template -struct Mask { - - static_assert(!(PackGQA && SwapAB), "Cannot be both PackGQA and SwapAB"); - - int const thread_idx; - int const seqlen_q, seqlen_k; - int const window_size_left, window_size_right, sink_token_length; - cutlass::FastDivmod const qhead_per_khead_divmod; - - CUTLASS_DEVICE - Mask(const int thread_idx, const int seqlen_q, const int seqlen_k, - const int window_size_left, const int window_size_right, const int sink_token_length, - cutlass::FastDivmod const &qhead_per_khead_divmod) - : thread_idx(thread_idx) - , seqlen_q(seqlen_q) - , seqlen_k(seqlen_k) - , window_size_left(window_size_left) - , window_size_right(window_size_right) - , sink_token_length(sink_token_length) - , qhead_per_khead_divmod(qhead_per_khead_divmod) - { - }; - - template - CUTLASS_DEVICE - void apply(Tensor &tSrS, const int m_block, const int n_block) const { - static_assert(!(Causal_mask && Local_mask), "Cannot be both causal and local"); - static_assert(Layout::rank == 3, "Only support 3D Tensor"); - if (!Seqlenk_mask && !Causal_mask && !Local_mask) { return; } - - auto thread_mma = TiledMma{}.get_thread_slice(thread_idx); - auto thread0_mma = TiledMma{}.get_thread_slice(_0{}); - - static constexpr int Row = !SwapAB ? 0 : 1, Col = !SwapAB ? 1 : 0; - - Tensor cS = cute::make_identity_tensor(Shape, Int>{}); - Tensor tScS = thread_mma.partition_C(cS); - Tensor tSrS_rowcol = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout())); - Tensor tScS_rowcol = make_tensor(tScS.data(), flash::convert_layout_acc_rowcol(tScS.layout())); - Tensor t0ScS = thread0_mma.partition_C(cS); - Tensor t0ScS_rowcol = make_tensor(t0ScS.data(), flash::convert_layout_acc_rowcol(t0ScS.layout())); - // We want to use the col indices of thread0 to compare, since that is known at compile time. - // So we subtract the limit by the first col index of this thread (get(tScS_rowcol(_0{}, _0{}))) - int const thread_col_offset = get(tScS_rowcol(_0{}, _0{})); - int const seqlenk_col_limit = seqlen_k - n_block * kBlockN - thread_col_offset; - if constexpr (!Causal_mask && !Local_mask) { - if constexpr (Seqlenk_mask) { // Just masking based on col - #pragma unroll - for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { - if (int(get(t0ScS_rowcol(_0{}, n))) >= seqlenk_col_limit) { - #pragma unroll - for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { tSrS_rowcol(m, n) = -INFINITY; } - } - } - } - } else { // mask based on both row and col - if constexpr (!SwapAB) { - // If PackGQA, we split the work of compute divmod among threads in the same row - static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{}); - static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0); - static_assert(!PackGQA || CUTE_STATIC_V(size<0>(tSrS_rowcol)) <= kMmaThreadsPerRow); - int mma_m_idx; - // Might get OOB but it's ok since we'll check it later - if constexpr (PackGQA) { - mma_m_idx = qhead_per_khead_divmod.divide(m_block * kBlockM + get(tScS_rowcol(thread_idx % kMmaThreadsPerRow, _0{}))); - } - int const causal_row_offset = 1 + seqlen_k - n_block * kBlockN - seqlen_q - thread_col_offset; - if constexpr (Causal_mask) { - #pragma unroll - for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { - int const row_idx = !PackGQA - ? get(tScS_rowcol(m, _0{})) + m_block * kBlockM - : __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow); - int const col_limit_right = !Seqlenk_mask - ? row_idx + causal_row_offset - : __viaddmin_s32(row_idx, causal_row_offset, seqlenk_col_limit); - #pragma unroll - for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { - if (int(get(t0ScS_rowcol(_0{}, n))) >= col_limit_right) { tSrS_rowcol(m, n) = -INFINITY; } - } - } - } else { - int const local_row_offset_right = causal_row_offset + window_size_right; - int const local_row_offset_left = causal_row_offset - 1 - window_size_left; - int const col_limit_sink = sink_token_length - n_block * kBlockN; - #pragma unroll - for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { - int const row_idx = !PackGQA - ? get(tScS_rowcol(m, _0{})) + m_block * kBlockM - : __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow); - int const col_limit_right = !Seqlenk_mask - ? row_idx + local_row_offset_right - : __viaddmin_s32(row_idx, local_row_offset_right, seqlenk_col_limit); - int const col_limit_left = row_idx + local_row_offset_left; - #pragma unroll - for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { - int const col_idx = int(get(t0ScS_rowcol(m, n))); - if (col_idx >= col_limit_right || (col_idx < col_limit_left && col_idx >= col_limit_sink)) { tSrS_rowcol(m, n) = -INFINITY; } - } - } - } - } else { - int const thread_row_offset = get(tScS_rowcol(_0{}, _0{})); - int const causal_row_offset = seqlenk_col_limit - seqlen_q + m_block * kBlockM + thread_row_offset; - if constexpr (Causal_mask) { - #pragma unroll - for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { - int const col0 = int(get(t0ScS_rowcol(_0{}, n))); - // If col0 is beyond the column limit, we want to mask out the entire column, by setting - // row limit to be kBlockM. - int const row_limit_top = col0 >= seqlenk_col_limit ? kBlockM : col0 - causal_row_offset; - #pragma unroll - for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { - if (int(get(t0ScS_rowcol(m, _0{}))) < row_limit_top) { tSrS_rowcol(m, n) = -INFINITY; } - } - } - } else { - int const col_limit_sink = sink_token_length - n_block * kBlockN - thread_col_offset; - #pragma unroll - for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { - int const col0 = int(get(t0ScS_rowcol(_0{}, n))); - // If col0 is beyond the column limit, we want to mask out the entire column, by setting - // row limit to be kBlockM. - int const row_limit_top = col0 >= seqlenk_col_limit ? kBlockM : col0 - causal_row_offset - window_size_right; - int const row_limit_bot = col0 < col_limit_sink ? kBlockM : col0 - causal_row_offset + window_size_left; - #pragma unroll - for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { - int const row_idx = int(get(t0ScS_rowcol(m, _0{}))); - if (row_idx < row_limit_top || row_idx > row_limit_bot) { tSrS_rowcol(m, n) = -INFINITY; } - } - } - } - } - } - }; - -}; - -} // namespace flash diff --git a/flash-attn/named_barrier.hpp b/flash-attn/named_barrier.hpp deleted file mode 100644 index a7dfb6439a234baffd8ffd154ab765c88021498c..0000000000000000000000000000000000000000 --- a/flash-attn/named_barrier.hpp +++ /dev/null @@ -1,72 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include "cutlass/arch/barrier.h" - -namespace flash { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// cutlass::arch::NamedBarrier::sync/arrive are only enabled Sm90 even though they work -// for Sm80 as well. We reimplement them here, enabled for both Sm90 and Sm80. - -CUTLASS_DEVICE -static void named_barrier_sync(uint32_t num_threads, uint32_t barrier_id_) { - static constexpr uint32_t ReservedNamedBarrierCount = static_cast(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier); - uint32_t barrier_id = barrier_id_ + ReservedNamedBarrierCount; - asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); - cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); -} - -CUTLASS_DEVICE -static void named_barrier_sync(uint32_t num_threads, cutlass::arch::ReservedNamedBarriers reserved_named_barriers) { - uint32_t barrier_id = static_cast(reserved_named_barriers); - asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); - cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); -} - -CUTLASS_DEVICE -static void named_barrier_arrive(uint32_t num_threads, uint32_t barrier_id_) { - static constexpr uint32_t ReservedNamedBarrierCount = static_cast(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier); - uint32_t barrier_id = barrier_id_ + ReservedNamedBarrierCount; - cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); - asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); -} - -CUTLASS_DEVICE -static void named_barrier_arrive(uint32_t num_threads, cutlass::arch::ReservedNamedBarriers reserved_named_barriers) { - uint32_t barrier_id = static_cast(reserved_named_barriers); - cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); - asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); -} - - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// Enumerates the reserved named barriers to avoid potential conflicts - -enum class FwdNamedBarriers { - QueryEmpty = 0, - WarpSchedulerWG1 = 1, - WarpSchedulerWG2 = 2, - WarpSchedulerWG3 = 3, - AppendKV = 4, - QueryRotated = 5, - PFull = 6, - PEmpty = 7, -}; - -enum class BwdNamedBarriers { - KVEmpty = 0, - PdS = 1, - dQEmptyWG1 = 2, - dQEmptyWG2 = 3, - dQEmptyWG3 = 4, - dQFullWG1 = 5, - dQFullWG2 = 6, - dQFullWG3 = 7, -}; - -} // flash diff --git a/flash-attn/pack_gqa.h b/flash-attn/pack_gqa.h deleted file mode 100644 index 160bf4306840bb55e72bfcc5ae6d10c99356e595..0000000000000000000000000000000000000000 --- a/flash-attn/pack_gqa.h +++ /dev/null @@ -1,255 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -#include "cutlass/fast_math.h" // For cutlass::FastDivmod - -#include "utils.h" - -namespace flash { - -using namespace cute; - -template -struct PackGQAManager { - // We use CpAsync for Q, since TMA doesn't work there - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static constexpr int kGmemElemsPerStore = kGmemElemsPerLoad; - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); - // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each - // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. - // In the case of PackGQA, this reduces the number of times we need to call divmod. - static constexpr int kBytePerRow = kHeadDim * sizeof(Element); - static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); - static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; - static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow"); - // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where - // these threads share the same page table entry and share the work of computing pointers to paged K and paged V. - static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); - using GmemCopyAtomCpAsync = cute::Copy_Atom, Element>; - using GmemLayoutAtom = Layout, Int>, - Stride, _1>>; - using GmemTiledCopyQCpAsync = decltype( - make_tiled_copy(GmemCopyAtomCpAsync{}, - GmemLayoutAtom{}, - Layout>>{})); // Val layout, 8 or 16 vals per load - - // Was trying to have each WG loading Q to the rows in sQ that only that WG needs so that we only need - // to sync within each WG, but didn't seem to be any faster. - // using GmemLayoutAtomWG = Layout, Int, Int >, - // Stride, _128, _1>>; - // using GmemTiledCopyQCpAsyncWG = decltype( - // make_tiled_copy(GmemCopyAtomCpAsync{}, - // GmemLayoutAtomNew{}, - // Layout>>{})); // Val layout, 8 or 16 vals per load - - using GmemTiledCopyO = decltype( - make_tiled_copy(Copy_Atom, Element>{}, - GmemLayoutAtom{}, - Layout>>{})); // Val layout, 8 or 16 vals per store - - template - CUTLASS_DEVICE - static auto - compute_ptr(Tensor &tensor, TensorC const &tRows, - cutlass::FastDivmod const &qhead_per_khead_divmod, int const thread_idx, int const m_block) { - // tensor of shape ((qhead_per_khead, seqlen_q)) - static constexpr int NumPtrPerThread = cute::ceil_div(CUTE_STATIC_V(cute::size(tRows)), NumThreadsPerRow); - using TensorType = typename Engine::value_type; - Tensor tPrPtr = make_tensor(Shape>{}); - #pragma unroll - for (int i = 0; i < NumPtrPerThread; ++i) { - int const row = i * NumThreads + get<0>(tRows(thread_idx % NumThreadsPerRow)); - int const idx = m_block * kBlockM + row; - int m_idx, h_idx; - m_idx = qhead_per_khead_divmod.divmod(h_idx, idx); - tPrPtr[i] = &tensor(make_coord(make_coord(h_idx, m_idx))); - } - return tPrPtr; - } - - - template - CUTLASS_DEVICE - static void - load_Q(TensormQ const &mQ, // ((qhead_per_khead, seqlen_q), headdim) - TensorsQ &sQ, // (kBlockM, kHeadDim) - cutlass::FastDivmod const &qhead_per_khead_divmod, - int const thread_idx, int const seqlen_q, int const m_block - ) - { - GmemTiledCopyQCpAsync gmem_tiled_copy_Q_cp_async; - // GmemTiledCopyQCpAsyncNew gmem_tiled_copy_Q_cp_async; - auto gmem_thr_copy_Q_cp_async = gmem_tiled_copy_Q_cp_async.get_thread_slice(thread_idx); - Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tQcQ = gmem_thr_copy_Q_cp_async.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tQsQ = gmem_thr_copy_Q_cp_async.partition_D(sQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - // Tensor tQcQ_ = gmem_thr_copy_Q_cp_async.partition_S(cute::flat_divide(cQ, _64{})); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - // Tensor tQsQ_ = gmem_thr_copy_Q_cp_async.partition_D(cute::flat_divide(sQ, _64{})); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - // Tensor tQcQ = group_modes<1, rank(tQcQ_) - 1>(tQcQ_); - // Tensor tQsQ = group_modes<1, rank(tQsQ_) - 1>(tQsQ_); - Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); - #pragma unroll - for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < size<1>(mQ); } - - // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for Q. - // We split the work among threads loading the same row of Q, then __shfl_sync the pointers. - Tensor mQ_0 = mQ(_, _0{}); - Tensor tQcQ_row = tQcQ(_0{}, _, _0{}); - Tensor tPrQPtr = compute_ptr(mQ_0, tQcQ_row, qhead_per_khead_divmod, thread_idx, m_block); - int const qhead_per_khead = qhead_per_khead_divmod.divisor; - #pragma unroll - for (int m = 0; m < size<1>(tQsQ); ++m) { - int idx = m_block * kBlockM + get<0>(tQcQ(_0{}, m, _0{})); - Element const* q_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrQPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); - if (idx < seqlen_q * qhead_per_khead) { - // if (thread_idx == 0) { printf("m: %d, m_idx: %d, h_idx: %d, q_ptr = %p, q_ptr_og = %p\n", m, m_idx, h_idx, q_ptr, &mQ_copy(0, make_coord(h_idx, m_idx), 0));} - Tensor mQ_cur = make_tensor(make_gmem_ptr(q_ptr), Shape>{}); - Tensor mQ_cur_copy = cute::tiled_divide(mQ_cur, Shape>{}); - #pragma unroll - for (int k = 0; k < size<2>(tQsQ); ++k) { - int ki = get<1>(tQcQ(_0{}, _0{}, k)) / kGmemElemsPerLoad; - // the "tiled_copy.with(tQpQ(k))"" will fill in zero for columns where tQpQ(k) is false - // TODO: check this - cute::copy(gmem_tiled_copy_Q_cp_async.with(tQpQ(k)), mQ_cur_copy(_, ki), tQsQ(_, m, k)); - } - } // Don't need to fill in 0s for sQ since we're not gonna write the output to gmem for those rows - } - }; - - template - CUTLASS_DEVICE - static void - store_LSE(TensormLSE &mLSE, // ((qhead_per_khead, seqlen_q)) - TensorsLSE const &tLSErLSE, // (kBlockM) split across threads according to tiled_mma - TiledMma tiled_mma, - cutlass::FastDivmod const &qhead_per_khead_divmod, - int const thread_idx, int const seqlen_o, int const m_block - ) - { - Tensor caccO = cute::make_identity_tensor(Shape, Int>{}); - auto thread_mma = tiled_mma.get_thread_slice(thread_idx); - Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) - Tensor taccOcO_row = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()))(_, _0{}); - CUTE_STATIC_ASSERT_V(size(tLSErLSE) == size(taccOcO_row)); // MMA_M - - // If PackGQA, we split the work of compute divmod among threads in the same row - static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{}); - static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0); - static_assert(CUTE_STATIC_V(size(tLSErLSE)) <= kMmaThreadsPerRow); - static_assert(CUTE_STATIC_V(size(taccOcO_row)) <= kMmaThreadsPerRow); - - Tensor tPrLSEPtr = compute_ptr(mLSE, taccOcO_row, qhead_per_khead_divmod, thread_idx, m_block); - static_assert(CUTE_STATIC_V(size(tPrLSEPtr)) == 1); - int const qhead_per_khead = qhead_per_khead_divmod.divisor; - #pragma unroll - for (int mi = 0; mi < size(tLSErLSE); ++mi) { - int const row = m_block * kBlockM + get<0>(taccOcO_row(mi)); - float* ptr_LSE_cur = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrLSEPtr[0]), mi % kMmaThreadsPerRow, kMmaThreadsPerRow)); - if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o * qhead_per_khead) { - *ptr_LSE_cur = tLSErLSE(mi); - } - } - }; - - template - CUTLASS_DEVICE - static void - store_O(TensormO &mO, // ((qhead_per_khead, seqlen_o), headdim) - TensorrO const &tOrO, // (kBlockM, kHeadDim) split across threads according to gmem_tiled_copy_O - cutlass::FastDivmod const &qhead_per_khead_divmod, - int const thread_idx, int const seqlen_o, int const m_block - ) - { - GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); - Tensor cO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); - #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < size<1>(mO); } - - // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for O. - // We split the work among threads loading the same row of O, then __shfl_sync the pointers. - Tensor mO_0 = mO(_, _0{}); - Tensor tOcO_row = tOcO(_0{}, _, _0{}); - Tensor tPrOPtr = compute_ptr(mO_0, tOcO_row, qhead_per_khead_divmod, thread_idx, m_block); - int const qhead_per_khead = qhead_per_khead_divmod.divisor; - #pragma unroll - for (int m = 0; m < size<1>(tOrO); ++m) { - int idx = m_block * kBlockM + get<0>(tOcO(_0{}, m, _0{})); - Element* o_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrOPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); - if (idx < seqlen_o * qhead_per_khead) { - Tensor mO_cur = make_tensor(make_gmem_ptr(o_ptr), Shape>{}); - Tensor mO_cur_copy = cute::tiled_divide(mO_cur, Shape>{}); - #pragma unroll - for (int k = 0; k < size<2>(tOrO); ++k) { - int ki = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerStore; - if (tOpO(k)) { - cute::copy(gmem_tiled_copy_O, tOrO(_, m, k), mO_cur_copy(_, ki)); - } - } - } - } - }; - - template - CUTLASS_DEVICE - static void - store_O_direct(TensormO &mO, // ((qhead_per_khead, seqlen_o), headdim) - TensorrO const &tOrO, // (kBlockM, kHeadDim) split across threads according to tiled_mma - TiledMma tiled_mma, - cutlass::FastDivmod const &qhead_per_khead_divmod, - int const thread_idx, int const seqlen_o, int const m_block - ) - { - static constexpr int kGmemElemsPerStoreDirect = 2; - cute::Copy_Atom, Element> gmem_copy_direct; - // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) - Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout())); - Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int>{}); - - Tensor caccO = cute::make_identity_tensor(Shape, Int>{}); - auto thread_mma = tiled_mma.get_thread_slice(thread_idx); - Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) - Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); - Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); - Tensor taccOcO_col = taccOcO_rowcol(_0{}, _); - - // If PackGQA, we split the work of compute divmod among threads in the same row - static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{}); - static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0); - static_assert(CUTE_STATIC_V(size(taccOcO_row)) <= kMmaThreadsPerRow); - - // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for O. - // We split the work among threads loading the same row of O, then __shfl_sync the pointers. - Tensor mO_0 = mO(_, _0{}); - Tensor tPrOPtr = compute_ptr(mO_0, taccOcO_row, qhead_per_khead_divmod, thread_idx, m_block); - static_assert(CUTE_STATIC_V(size(tPrOPtr)) == 1); - - int const qhead_per_khead = qhead_per_khead_divmod.divisor; - #pragma unroll - for (int m = 0; m < size<1>(tOrO_copy); ++m) { - int row = m_block * kBlockM + get<0>(taccOcO_row(m)); - Element* o_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrOPtr[0]), m % kMmaThreadsPerRow, kMmaThreadsPerRow)); - if (row < seqlen_o * qhead_per_khead) { - Tensor mO_cur = make_tensor(make_gmem_ptr(o_ptr), Shape>{}); - Tensor mO_cur_copy = cute::tiled_divide(mO_cur, Shape>{}); - #pragma unroll - for (int k = 0; k < size<2>(tOrO_copy); ++k) { - int col = get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)); - if (col < size<1>(mO)) { - cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), mO_cur_copy(_, col / kGmemElemsPerStoreDirect)); - } - } - } - } - }; - -}; - -} // namespace flash diff --git a/flash-attn/paged_kv.h b/flash-attn/paged_kv.h deleted file mode 100644 index 9ea59bcc2a2378bf0de1d626410146a93cc74420..0000000000000000000000000000000000000000 --- a/flash-attn/paged_kv.h +++ /dev/null @@ -1,354 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -#include "cutlass/fast_math.h" // For cutlass::FastDivmod - -#include "utils.h" - -namespace flash { - -using namespace cute; - -template -struct PagedKVManager { - // If KV_Same_Iter=false, then we do load_page_table(0), load_K(0), load_page_table(1), load_K(1), load_V(0), - // load_page_table(2), load_K(2), load_V(1), etc. - // So we need to compute the V pointers for the previous iteration. - - // LoadsPerRow_LB is the lower bound on number of loads per row in the K direction. This is useful for - // rotary where we want each thread to have at least 2 loads per row. - - static constexpr bool SameHeadDim = (kHeadDim == kHeadDimV); - static constexpr int kHeadDimGCD = cute::gcd(kHeadDim, kHeadDimV); - - // We use CpAsync for K and V if PagedKV, since TMA doesn't work there - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDimGCD % kGmemElemsPerLoad == 0, "Headdim and HeaddimV must be a multiple of kGmemElemsPerLoad"); - // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each - // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. - // In the case of PackGQA, this reduces the number of times we need to call divmod. - static_assert(kHeadDimGCD % LoadsPerRow_LB == 0, "Headdim and HeaddimV must be a multiple of LoadsPerRow_LB"); - static constexpr int kBytePerRow = kHeadDimGCD / LoadsPerRow_LB * sizeof(Element); - static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); - static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; - static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow"); - // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where - // these threads share the same page table entry and share the work of computing pointers to paged K and paged V. - static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); - using GmemCopyAtomCpAsync = cute::Copy_Atom, Element>; - using GmemLayoutAtomKVCpAsync = Layout, Int>, - Stride, _1>>; - using GmemTiledCopyKVCpAsync = decltype( - make_tiled_copy(GmemCopyAtomCpAsync{}, - GmemLayoutAtomKVCpAsync{}, - Layout>>{})); // Val layout, 8 or 16 vals per load - using GmemTiledCopyKVStore = decltype( - make_tiled_copy(Copy_Atom, Element>{}, - GmemLayoutAtomKVCpAsync{}, - Layout>>{})); // Val layout, 8 or 16 vals per load - - using ShapeKV = cute::Shape; // (seqlen, d, head, batch) - using StrideKV = cute::Stride; - using ShapePageTable = cute::Shape; // (batch, max_num_pages_per_seq) - using StridePageTable = cute::Stride; - - using TensorPageTable = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapePageTable{}, StridePageTable{})(int(0), _)); - using TensorKV = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeKV{}, StrideKV{})(_, _, int(0), _)); - using GmemThrCopyKVCpAsync = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0))); - using TensortKcK = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); - using TensortKpK = decltype(make_tensor(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{})); - using TensortVcV = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); - using TensortVpV = decltype(make_tensor(make_shape(size<1>(TensortVcV{}), size<2>(TensortVcV{})), Stride<_0, _1>{})); - - // For PagedKV, it's expensive the calculate the pointers to K and V for each page table entry, - // since those require int64_t arithmetic. We optimize by having threads split this work. - // Typically there are 8 threads loading per row (e.g. hdim 64 and 128), and there are 11 rows - // that each thread needs to load for the case of hdim 128 and kBlockN = 176. - // So each of those 8 threads will calculate the K_ptr and V_ptr for 11 / 8 = 2 rows. - // We then use __shfl_sync to broadcast the pointers to the other threads in the warp. - static_assert(CUTE_STATIC_V(size<1>(TensortKcK{})) == CUTE_STATIC_V(size<1>(TensortVcV{}))); - static constexpr int kPageEntryPerThread = cute::ceil_div(size<1>(TensortKcK{}), kGmemThreadsPerRow); - using TensorPageOffset = decltype(make_tensor>(Shape>{})); - using TensorKVPtr = decltype(make_tensor(Shape>{})); - - GmemTiledCopyKVCpAsync gmem_tiled_copy_kv; - cutlass::FastDivmod const &page_size_divmod; - cutlass::FastDivmod const &blockN_per_page_size_divmod; - int const thread_idx; - int const seqlen_k; - int const leftpad_k; - int const* const ptr_page_table; - GmemThrCopyKVCpAsync const gmem_thr_copy_kv; - TensorPageTable mPageTable; - TensorKV mK_paged, mV_paged; - TensortKpK tKpK; - TensortVpV tVpV; - TensorPageOffset tPrPageOffset; - TensorKVPtr tPrVPtr; - int bidb_kv_idx, bidb_kv_idx_prev, n_block_idx, n_block_idx_prev; // Only used for TMA - - CUTLASS_DEVICE - PagedKVManager(int const* const ptr_page_table_, - ShapePageTable const &shape_pagetable, StridePageTable const &stride_pagetable, - Element* const ptr_K, ShapeKV const &shape_K, StrideKV const &stride_K, - Element* const ptr_V, int const headdim_v, StrideKV const &stride_V, - cutlass::FastDivmod const &page_size_divmod, - cutlass::FastDivmod const &blockN_per_page_size_divmod, - int const bidb, int const bidh, int const thread_idx, int const seqlen_k, int const leftpad_k, - int bidb_kv_idx - ) - : page_size_divmod(page_size_divmod) - , blockN_per_page_size_divmod(blockN_per_page_size_divmod) - , thread_idx(thread_idx) - , seqlen_k(seqlen_k) - , leftpad_k(leftpad_k) - , ptr_page_table(ptr_page_table_) - , gmem_thr_copy_kv(gmem_tiled_copy_kv.get_thread_slice(thread_idx)) - , bidb_kv_idx(bidb_kv_idx) - , bidb_kv_idx_prev(bidb_kv_idx) - - { - mPageTable = make_tensor(make_gmem_ptr(ptr_page_table), shape_pagetable, stride_pagetable)(bidb, _); - mK_paged = make_tensor(make_gmem_ptr(ptr_K), shape_K, stride_K)(_, _, bidh, _); - auto shape_V = make_shape(get<0>(shape_K), headdim_v, get<2>(shape_K), get<3>(shape_K)); - mV_paged = make_tensor(make_gmem_ptr(ptr_V), shape_V, stride_V)(_, _, bidh, _); - tKpK = make_tensor(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{}); - Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) - Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); - #pragma unroll - for (int k = 0; k < size<1>(tKpK); ++k) { tKpK(_0{}, k) = get<1>(tKcK(_0{}, _0{}, k)) < get<1>(shape_K); } - Tensor tVpV_ = make_tensor(make_shape(size<1>(TensortVcV{}), size<2>(TensortVcV{})), Stride<_0, _1>{}); - Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) - Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); - #pragma unroll - for (int k = 0; k < size<1>(tVpV_); ++k) { tVpV_(_0{}, k) = get<1>(tVcV(_0{}, _0{}, k)) < get<1>(shape_V); } - tVpV = cute::conditional_return(tKpK, tVpV_); - }; - - template - CUTLASS_DEVICE - void load_page_table(const int n_block) { - // The uncoalesced gmem load is intentional. This is so that each thread only loads the page table entries - // it needs, and we don't need any sync between warps. - // Assuming 8 threads per row, and 176 rows, then the rows from 0 to 175 are loaded by - // threads 0, 8, 16, ..., 120, 1, 9, ..., 121, 2, 10, ..., 122, etc. - #pragma unroll - for (int i = 0; i < kPageEntryPerThread; ++i) { - int const row = i * NumThreads + (thread_idx % kGmemThreadsPerRow) * (NumThreads / kGmemThreadsPerRow) + (thread_idx / kGmemThreadsPerRow); - int const row_idx = n_block * kBlockN + row; - int page_idx, page_offset; - page_idx = page_size_divmod.divmod(page_offset, row_idx + leftpad_k); - // Add the condition (i + 1) * NumThreads <= kBlockN since that is an upper bound of row - // and is known at compile time. It avoids branching when e.g., kBlockN = 176 and i = 0. - int const page = ((i + 1) * NumThreads <= kBlockN || row < kBlockN) && (!Seqlenk_mask || row_idx < seqlen_k) ? mPageTable[page_idx] : 0; - tPrPageOffset[i] = {page, page_offset}; - // if (cute::thread0()) { printf("row = %d, page_idx = %d, page_offset = %d, page = %d, leftpad_k = %d, seqlen_k = %d\n", row, page_idx, page_offset, page, leftpad_k, seqlen_k); } - } - if constexpr (First_iter && !KV_Same_Iter) { compute_V_ptr(); } - }; - - template - CUTLASS_DEVICE - void load_page_table_TMA(const int n_block) { - // We require that page size is a multiple of kBlockN, and there's no leftpad_k - if (ptr_page_table) { - bidb_kv_idx = mPageTable[blockN_per_page_size_divmod.divmod(n_block_idx, n_block)]; - } else { - n_block_idx = n_block; - } - if constexpr (First_iter && !KV_Same_Iter) { - bidb_kv_idx_prev = bidb_kv_idx; - n_block_idx_prev = n_block_idx; - } - }; - - CUTLASS_DEVICE - cute::tuple get_indices_for_K_TMA() { - return {n_block_idx, bidb_kv_idx}; - }; - - CUTLASS_DEVICE - cute::tuple get_indices_for_V_TMA() { - if constexpr (KV_Same_Iter) { - return {n_block_idx, bidb_kv_idx}; - } else { - cute::tuple const indices = {n_block_idx_prev, bidb_kv_idx_prev}; - bidb_kv_idx_prev = bidb_kv_idx; - n_block_idx_prev = n_block_idx; - return indices; - } - }; - - CUTLASS_DEVICE - TensorKVPtr compute_K_ptr() { - Tensor tPrKPtr = make_tensor(Shape>{}); - #pragma unroll - for (int i = 0; i < kPageEntryPerThread; ++i) { - auto [page, page_offset] = tPrPageOffset[i]; - tPrKPtr[i] = &mK_paged(page_offset, _0{}, page); - } - return tPrKPtr; - }; - - CUTLASS_DEVICE - void compute_V_ptr() { - #pragma unroll - for (int i = 0; i < kPageEntryPerThread; ++i) { - auto [page, page_offset] = tPrPageOffset[i]; - tPrVPtr[i] = &mV_paged(page_offset, _0{}, page); - } - }; - - template - CUTLASS_DEVICE - void load_K(const int n_block, TensorK &&sK) { - // Do we need bound check to make sure the row doesn't go above kBlockN - static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtomKVCpAsync{})) == 0; - - Tensor tPrKPtr = compute_K_ptr(); - - // Only for index calculation, since all the indices of thread 0 are known at compile time - auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); - Tensor tKsK = gmem_thr_copy_kv.partition_D(sK); - Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) - // Repeat the partitioning with identity layouts - Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); - Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK); - - // We want to use the row indices of thread0 to compare, since that is known at compile time. - // So we subtract the limit by the first row index of this thread (get<0>(tKcK(_0{}, _0{}, _0{}))) - int const seqlenk_row_limit = -int(get<0>(tKcK(_0{}, _0{}, _0{}))) + (EvenN - ? seqlen_k - n_block * kBlockN - : (!Seqlenk_mask ? kBlockN : std::min(seqlen_k - n_block * kBlockN, kBlockN))); - #pragma unroll - for (int m = 0; m < size<1>(tKsK); ++m) { - bool const should_load = EvenN - ? (!Seqlenk_mask || get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit) - : get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit; - Element const* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); - Tensor mK_paged_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); - Tensor mK_paged_cur_copy = cute::tiled_divide(mK_paged_cur, Shape>{}); - if (should_load) { - #pragma unroll - for (int k = 0; k < size<2>(tKsK); ++k) { - int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; - cute::copy(gmem_tiled_copy_kv.with(tKpK(_0{}, k)), mK_paged_cur_copy(_, ki), tKsK(_, m, k)); - } - } // Don't need to clear out the rest of the smem since we'll mask out the scores anyway - } - }; - - template - CUTLASS_DEVICE - void load_V(const int n_block, TensorV &&sV) { - // Do we need bound check to make sure the row doesn't go above kBlockN - static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtomKVCpAsync{})) == 0; - - if constexpr (KV_Same_Iter) { compute_V_ptr(); } - // Only for index calculation, since all the indices of thread 0 are known at compile time - auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); - Tensor tVsV = gmem_thr_copy_kv.partition_D(sV); - Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) - // Repeat the partitioning with identity layouts - Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); - Tensor t0VcV = gmem_thr0_copy_kv.partition_S(cV); - - int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tVcV(_0{}, _0{}, _0{})); - #pragma unroll - for (int m = 0; m < size<1>(tVsV); ++m) { - // Faster to rely on the cp.async to clear smem that are out of bound, - // rather than calling cute::clear directly. - // We have to be careful not to write to smem past `kBlockN` if !EvenN. - // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to checked - if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tVcV(_0{}, m, _0{})) < kBlockN) { - bool const should_load = !Seqlenk_mask || get<0>(t0VcV(_0{}, m, _0{})) < seqlenk_row_limit; - Element const* v_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); - Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); - Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape>{}); - #pragma unroll - for (int k = 0; k < size<2>(tVsV); ++k) { - int const ki = get<1>(tVcV(_0{}, _0{}, k)) / kGmemElemsPerLoad; - cute::copy(gmem_tiled_copy_kv.with(tVpV(_0{}, k) && should_load), mV_paged_cur_copy(_, ki), tVsV(_, m, k)); - } - } - } - if constexpr (!KV_Same_Iter) { compute_V_ptr(); } - }; - - template - CUTLASS_DEVICE - void store_K(const int n_block, TensorK &&tKrK) { - Tensor tPrKPtr = compute_K_ptr(); - // We're using the same partitioning as GmemTiledCopyKVCpAsync (used for loading) - // Only for index calculation, since all the indices of thread 0 are known at compile time - auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); - Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) - // Repeat the partitioning with identity layouts - Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); - Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK); - - GmemTiledCopyKVStore gmem_tiled_copy_kv_store; - // We want to use the row indices of thread0 to compare, since that is known at compile time. - // So we subtract the limit by the first row index of this thread (get<0>(tKcK(_0{}, _0{}, _0{}))) - // int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tKcK(_0{}, _0{}, _0{})); - int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tKcK(_0{}, _0{}, _0{})); - // if (threadIdx.x == 128) { printf("bidx = %d, bidy = %d, bidz = %d, seqlen_k = %d, seqlenk_row_limit = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, seqlen_k, seqlenk_row_limit); } - #pragma unroll - for (int m = 0; m < size<1>(tKrK); ++m) { - bool const should_load = get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit; - Element* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); - Tensor mK_paged_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); - Tensor mK_paged_cur_copy = cute::tiled_divide(mK_paged_cur, Shape>{}); - if (should_load) { - #pragma unroll - for (int k = 0; k < size<2>(tKrK); ++k) { - int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; - if (tKpK(_0{}, k)) { - cute::copy(gmem_tiled_copy_kv_store, tKrK(_, m, k), mK_paged_cur_copy(_, ki)); - } - } - } - } - }; - - template - CUTLASS_DEVICE - void store_V(const int n_block, TensorV &&tVrV) { - if constexpr (KV_Same_Iter) { compute_V_ptr(); } - // Only for index calculation, since all the indices of thread 0 are known at compile time - auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); - Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) - // Repeat the partitioning with identity layouts - Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); - Tensor t0VcV = gmem_thr0_copy_kv.partition_S(cV); - - GmemTiledCopyKVStore gmem_tiled_copy_kv_store; - int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tVcV(_0{}, _0{}, _0{})); - #pragma unroll - for (int m = 0; m < size<1>(tVrV); ++m) { - bool const should_load = get<0>(t0VcV(_0{}, m, _0{})) < seqlenk_row_limit; - Element* v_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); - Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); - Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape>{}); - if (should_load) { - #pragma unroll - for (int k = 0; k < size<2>(tVrV); ++k) { - int const ki = get<1>(tVcV(_0{}, _0{}, k)) / kGmemElemsPerLoad; - if (tVpV(_0{}, k)) { - cute::copy(gmem_tiled_copy_kv_store, tVrV(_, m, k), mV_paged_cur_copy(_, ki)); - } - } - } - } - if constexpr (!KV_Same_Iter) { compute_V_ptr(); } - }; - - -}; - -} // namespace flash diff --git a/flash-attn/rotary.h b/flash-attn/rotary.h deleted file mode 100644 index aa3602cc79557d720d3bdd3c500be0e895a88ba5..0000000000000000000000000000000000000000 --- a/flash-attn/rotary.h +++ /dev/null @@ -1,489 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -#include "utils.h" - -namespace flash { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTLASS_DEVICE void -apply_rotary_interleaved(Tensor &rK, - Tensor const &rCos, - Tensor const &rSin) { - CUTE_STATIC_ASSERT_V(rank(rK) == _1{}); - CUTE_STATIC_ASSERT_V(rank(rCos) == _1{}); - CUTE_STATIC_ASSERT_V(rank(rSin) == _1{}); - CUTE_STATIC_ASSERT_V(size<0>(rCos) == size<0>(rSin)); - static_assert(decltype(size<0>(rK))::value == decltype(size<0>(rCos))::value * 2); - static_assert(decltype(size<0>(rCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - Tensor K_fp32 = make_tensor_like(rK); - convert_type_out(rK, K_fp32); - Tensor cos_fp32 = make_tensor_like(rCos); - convert_type_out(rCos, cos_fp32); - Tensor sin_fp32 = make_tensor_like(rSin); - convert_type_out(rSin, sin_fp32); - #pragma unroll - for (int i = 0; i < size<0>(K_fp32) / 2; ++i) { - float real = K_fp32[2 * i] * cos_fp32[i] - K_fp32[2 * i + 1] * sin_fp32[i]; - float imag = K_fp32[2 * i] * sin_fp32[i] + K_fp32[2 * i + 1] * cos_fp32[i]; - K_fp32[2 * i] = real; - K_fp32[2 * i + 1] = imag; - } - convert_type_out(K_fp32, rK); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTLASS_DEVICE void -apply_rotary_contiguous(Tensor &rK_left, - Tensor &rK_right, - Tensor const &rCos, - Tensor const &rSin) { - CUTE_STATIC_ASSERT_V(rank(rK_left) == _1{}); - CUTE_STATIC_ASSERT_V(rank(rK_right) == _1{}); - CUTE_STATIC_ASSERT_V(rank(rCos) == _1{}); - CUTE_STATIC_ASSERT_V(rank(rSin) == _1{}); - CUTE_STATIC_ASSERT_V(size<0>(rK_left) == size<0>(rK_right)); - CUTE_STATIC_ASSERT_V(size<0>(rK_left) == size<0>(rCos)); - CUTE_STATIC_ASSERT_V(size<0>(rCos) == size<0>(rSin)); - static_assert(decltype(size<0>(rCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - Tensor K_left_fp32 = make_tensor_like(rK_left); - convert_type_out(rK_left, K_left_fp32); - Tensor K_right_fp32 = make_tensor_like(rK_right); - convert_type_out(rK_right, K_right_fp32); - Tensor cos_fp32 = make_tensor_like(rCos); - convert_type_out(rCos, cos_fp32); - Tensor sin_fp32 = make_tensor_like(rSin); - convert_type_out(rSin, sin_fp32); - #pragma unroll - for (int i = 0; i < size<0>(K_left_fp32); ++i) { - float real = K_left_fp32[i] * cos_fp32[i] - K_right_fp32[i] * sin_fp32[i]; - float imag = K_left_fp32[i] * sin_fp32[i] + K_right_fp32[i] * cos_fp32[i]; - K_left_fp32[i] = real; - K_right_fp32[i] = imag; - } - convert_type_out(K_left_fp32, rK_left); - convert_type_out(K_right_fp32, rK_right); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Rotary { - - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); - // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each - // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. - // We want each thread to have at least 2 loads in the K direction since in the case of non-interleaved - // rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), each thread will - // load twice from the same row. - static constexpr int kBytePerHalfRow = kHeadDim / 2 * sizeof(Element); - static constexpr int kBlockKGmem = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element); - static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; - static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow"); - // We assume threads loading the same row are in the same warp. - static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); - - using LayoutAtom = Layout, Int>, - Stride, _1>>; - using TiledCopyQK = decltype( - make_tiled_copy(Copy_Atom, Element>{}, - LayoutAtom{}, - Layout>>{})); // Val layout, 8 or 16 vals per store - using GmemTiledCopyRotary = decltype( - make_tiled_copy(Copy_Atom, Element>{}, - LayoutAtom{}, - Layout>>{})); // Val layout, 4 or 8 vals per store - using GmemTiledCopyRotaryCont = decltype( - make_tiled_copy(Copy_Atom, Element>{}, - LayoutAtom{}, - Layout>>{})); // Val layout, 8 or 16 vals per store - - using ShapeRotary = cute::Shape; // (seqlen_ro, rotary_dim // 2) - using StrideRotary = cute::Stride; - - using GmemThrCopyRotary = decltype(GmemTiledCopyRotary{}.get_thread_slice(int(0))); - using GmemThrCopyRotaryCont = decltype(GmemTiledCopyRotaryCont{}.get_thread_slice(int(0))); - using TensortRcR = decltype(GmemTiledCopyRotary{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); - using TensortRpR = decltype(make_tensor(make_shape(size<2>(TensortRcR{})))); - using TensortRcRCont = decltype(GmemTiledCopyRotaryCont{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); - using TensortRpRCont = decltype(make_tensor(make_shape(size<2>(TensortRcRCont{})))); - using TensormR = decltype(make_tensor( - make_gmem_ptr((Element const*)nullptr), - ShapeRotary{}, - make_stride(cute::conditional_return(_0{}, int64_t(0)), _1{}))); - using TensortRgR = decltype( - GmemTiledCopyRotary{}.get_thread_slice(int(0)).partition_S(make_tensor( - make_gmem_ptr((Element const*)nullptr), - make_shape(Int{}, Int{}, int(0)), - make_stride(cute::conditional_return(_0{}, int64_t(0)), _1{}, cute::conditional_return(_0{}, int64_t(0)))))); - using TensortRgRCont = decltype( - GmemTiledCopyRotaryCont{}.get_thread_slice(int(0)).partition_S(make_tensor( - make_gmem_ptr((Element const*)nullptr), - make_shape(Int{}, Int{}, int(0)), - make_stride(cute::conditional_return(_0{}, int64_t(0)), _1{}, cute::conditional_return(_0{}, int64_t(0)))))); - - GmemTiledCopyRotary gmem_tiled_copy_rotary; - GmemTiledCopyRotaryCont gmem_tiled_copy_rotary_cont; - bool const is_rotary_interleaved; - int const rotary_dim; - int const thread_idx; - int const max_seqlen; - GmemThrCopyRotary const gmem_thr_copy_rotary; - GmemThrCopyRotaryCont const gmem_thr_copy_rotary_cont; - TensortRpR tRpR; - TensortRpRCont tRpRCont; - TensormR mCos, mSin; - TensortRgR tRgCos, tRgSin; - TensortRgRCont tRgCosCont, tRgSinCont; - - CUTLASS_DEVICE - Rotary(Element const* const ptr_rotary_cos, ShapeRotary const &shape_rotary, StrideRotary const &stride_rotary_cos_, - Element const* const ptr_rotary_sin, StrideRotary const &stride_rotary_sin_, - bool const is_rotary_interleaved, int const thread_idx, int const max_seqlen, int const start_idx) - : is_rotary_interleaved(is_rotary_interleaved) - , rotary_dim(get<1>(shape_rotary) * 2) - , thread_idx(thread_idx) - , max_seqlen(max_seqlen) - , gmem_thr_copy_rotary(gmem_tiled_copy_rotary.get_thread_slice(thread_idx)) - , gmem_thr_copy_rotary_cont(gmem_tiled_copy_rotary_cont.get_thread_slice(thread_idx)) - - { - auto stride_rotary_cos = make_stride(cute::conditional_return(get<0>(stride_rotary_cos_), _0{}), get<1>(stride_rotary_cos_)); - auto stride_rotary_sin = make_stride(cute::conditional_return(get<0>(stride_rotary_sin_), _0{}), get<1>(stride_rotary_sin_)); - mCos = make_tensor(make_gmem_ptr(ptr_rotary_cos + start_idx * get<0>(stride_rotary_cos_)), shape_rotary, stride_rotary_cos); - mSin = make_tensor(make_gmem_ptr(ptr_rotary_sin + start_idx * get<0>(stride_rotary_sin_)), shape_rotary, stride_rotary_sin); - Tensor gCos = local_tile(mCos, Shape, Int>{}, make_coord(_, _0{})); // (MN, K / 2, _) - Tensor gSin = local_tile(mSin, Shape, Int>{}, make_coord(_, _0{})); // (MN, K / 2, _) - tRgCos = gmem_thr_copy_rotary.partition_S(gCos); - tRgSin = gmem_thr_copy_rotary.partition_S(gSin); - tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCos); - tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSin); - Tensor cR = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K / 2) - Tensor tRcR = gmem_thr_copy_rotary.partition_D(cR); - tRpR = make_tensor(make_shape(size<2>(tRcR))); - #pragma unroll - for (int k = 0; k < size(tRpR); ++k) { tRpR(k) = get<1>(tRcR(_0{}, _0{}, k)) < get<1>(shape_rotary); } - Tensor tRcRCont = gmem_thr_copy_rotary_cont.partition_D(cR); - tRpRCont = make_tensor(make_shape(size<2>(tRcRCont))); - #pragma unroll - for (int k = 0; k < size(tRpRCont); ++k) { tRpRCont(k) = get<1>(tRcRCont(_0{}, _0{}, k)) < get<1>(shape_rotary); } - }; - - template - CUTLASS_DEVICE - auto load_cos_sin(int const block) { - using GmemTiledCopyRo = std::conditional_t; - auto gmem_thr_copy_ro = cute::conditional_return(gmem_thr_copy_rotary, gmem_thr_copy_rotary_cont); - Tensor tRpRCur = cute::conditional_return(tRpR, tRpRCont); - Tensor tRgCosCur = cute::conditional_return(tRgCos, tRgCosCont)(_, _, _, block); - Tensor tRgSinCur = cute::conditional_return(tRgSin, tRgSinCont)(_, _, _, block); - // make_tensor_like, not make_fragment_like. If the row_stride is _0{} we want to keep it that way - Tensor tRrCos = make_tensor_like(tRgCosCur); - Tensor tRrSin = make_tensor_like(tRgSinCur); - Tensor cR = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K / 2) - Tensor tRcR = gmem_thr_copy_ro.partition_D(cR); - // If FixedPosition, only copy the first row as we only need the cos/sin for position cache_seqlens - #pragma unroll - for (int m = 0; m < (!FixedPosition ? size<1>(tRrCos) : 1); ++m) { - if (get<0>(tRcR(_0{}, m, _0{})) < std::min(max_seqlen - block * kBlockMN, kBlockMN)) { - #pragma unroll - for (int k = 0; k < size<2>(tRrCos); ++k) { - if (tRpRCur(k)) { - cute::copy(GmemTiledCopyRo{}, tRgCosCur(_, m, k), tRrCos(_, m, k)); - cute::copy(GmemTiledCopyRo{}, tRgSinCur(_, m, k), tRrSin(_, m, k)); - } - } - } - } - return cute::make_tuple(tRrCos, tRrSin);; - } - - template - CUTLASS_DEVICE - auto load_cos_sin_packgqa(int const block, cutlass::FastDivmod const &qhead_per_khead_divmod) { - static constexpr int kGmemElemsPerLoadCur = kInterleaved ? kGmemElemsPerLoad / 2 : kGmemElemsPerLoad; - using GmemTiledCopyRo = std::conditional_t; - auto gmem_thr_copy_ro = cute::conditional_return(gmem_thr_copy_rotary, gmem_thr_copy_rotary_cont); - Tensor tRpRCur = cute::conditional_return(tRpR, tRpRCont); - // make_tensor_like, not make_fragment_like. If the row_stride is _0{} we want to keep it that way - Tensor tRrCos = make_tensor_like(cute::conditional_return(tRgCos, tRgCosCont)(_, _, _, _0{})); - Tensor tRrSin = make_tensor_like(cute::conditional_return(tRgSin, tRgSinCont)(_, _, _, _0{})); - int const qhead_per_khead = qhead_per_khead_divmod.divisor; - Tensor cR = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K / 2) - Tensor tRcR = gmem_thr_copy_ro.partition_D(cR); - - // The main bottleneck here is actually instruction cache misses. - - // Similar to PagedKVNonTMA, it's expensive to compute the pointers. - // We split the work among threads loading the same row, then __shfl_sync the pointers. - static constexpr int NumPtrPerThread = cute::ceil_div(CUTE_STATIC_V(cute::size<1>(tRrCos)), kGmemThreadsPerRow); - Tensor tPrCosPtr = make_tensor(Shape>{}); - Tensor tPrSinPtr = make_tensor(Shape>{}); - #pragma unroll - for (int i = 0; i < NumPtrPerThread; ++i) { - int const row = i * NumThreads + get<0>(tRcR(_0{}, thread_idx % kGmemThreadsPerRow, _0{})); - int const idx = block * kBlockMN + row; - int row_actual = qhead_per_khead_divmod.divide(idx); - tPrCosPtr[i] = &mCos(row_actual, _0{}); - tPrSinPtr[i] = &mSin(row_actual, _0{}); - } - - #pragma unroll - for (int m = 0; m < (!FixedPosition ? size<1>(tRgCos) : 1); ++m) { - int const idx = block * kBlockMN + get<0>(tRcR(_0{}, m, _0{})); - Element const* cos_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrCosPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); - Element const* sin_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrSinPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); - if (idx < max_seqlen * qhead_per_khead) { - Tensor mCos_copy = cute::tiled_divide(make_tensor(make_gmem_ptr(cos_ptr), Shape>{}), - Shape>{}); - Tensor mSin_copy = cute::tiled_divide(make_tensor(make_gmem_ptr(sin_ptr), Shape>{}), - Shape>{}); - #pragma unroll - for (int k = 0; k < size<2>(tRgCos); ++k) { - int const ki = get<1>(tRcR(_0{}, _0{}, k)) / (kGmemElemsPerLoadCur); - if (tRpRCur(k)) { - cute::copy(GmemTiledCopyRo{}, mCos_copy(_, ki), tRrCos(_, m, k)); - cute::copy(GmemTiledCopyRo{}, mSin_copy(_, ki), tRrSin(_, m, k)); - } - } - } - } - return cute::make_tuple(tRrCos, tRrSin); - } - - template - CUTLASS_DEVICE - void - apply_Q_interleaved(TensorsQ &sQ, // (kBlockM, kHeadDim) - TensortRrR const &tRrCos, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotary - TensortRrR const &tRrSin, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotary - int const m_block, int const qhead_per_khead=1) - { - TiledCopyQK tiled_copy_q; - auto gmem_thr_copy_q = tiled_copy_q.get_thread_slice(thread_idx); - Tensor tQsQ = gmem_thr_copy_q.partition_S(sQ); - Tensor tQcQ = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape, Int>{})); - - CUTE_STATIC_ASSERT_V(rank(tQsQ) == _3{}); - CUTE_STATIC_ASSERT_V(rank(tRrCos) == _3{}); - CUTE_STATIC_ASSERT_V(rank(tRrSin) == _3{}); - CUTE_STATIC_ASSERT_V(size<1>(tQsQ) == size<1>(tRrCos)); - CUTE_STATIC_ASSERT_V(size<2>(tQsQ) == size<2>(tRrCos)); - CUTE_STATIC_ASSERT_V(size<1>(tQsQ) == size<1>(tRrSin)); - CUTE_STATIC_ASSERT_V(size<2>(tQsQ) == size<2>(tRrSin)); - CUTE_STATIC_ASSERT_V(size<0>(tRrCos) == size<0>(tRrSin)); - static_assert(decltype(size<0>(tQsQ))::value == decltype(size<0>(tRrCos))::value * 2); - static_assert(decltype(size<0>(tRrCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - - #pragma unroll - for (int m = 0; m < size<1>(tQsQ); ++m) { - if (get<0>(tQcQ(_0{}, m, _0{})) < std::min(max_seqlen * qhead_per_khead - m_block * kBlockMN, kBlockMN)) { - #pragma unroll - for (int k = 0; k < size<2>(tQsQ); ++k) { - if (tRpR(k)) { - Tensor rQ = make_fragment_like(tQsQ(_, m, k)); - cute::copy(tiled_copy_q, tQsQ(_, m, k), rQ); - apply_rotary_interleaved(rQ, tRrCos(_, m, k), tRrSin(_, m, k)); - cute::copy(tiled_copy_q, rQ, tQsQ(_, m, k)); - } - } - } - } - }; - - template - CUTLASS_DEVICE - void - apply_Q_contiguous(TensorsQ &sQ, // (kBlockM, kHeadDim) - TensortRrR const &tRrCosCont, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotaryCont - TensortRrR const &tRrSinCont, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotaryCont - int const m_block, int const qhead_per_khead=1) - { - TiledCopyQK tiled_copy_q; - auto gmem_thr_copy_q = tiled_copy_q.get_thread_slice(thread_idx); - Tensor sQ_copy = cute::tiled_divide(sQ, Shape<_1, Int>{}); - Tensor tQcQ = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape, Int>{})); - - CUTE_STATIC_ASSERT_V(rank(tQcQ) == _3{}); - CUTE_STATIC_ASSERT_V(rank(tRrCosCont) == _3{}); - CUTE_STATIC_ASSERT_V(rank(tRrSinCont) == _3{}); - CUTE_STATIC_ASSERT_V(size<1>(tQcQ) == size<1>(tRrCosCont)); - CUTE_STATIC_ASSERT_V(size<2>(tQcQ) == size<2>(tRrCosCont)); - CUTE_STATIC_ASSERT_V(size<1>(tQcQ) == size<1>(tRrSinCont)); - CUTE_STATIC_ASSERT_V(size<2>(tQcQ) == size<2>(tRrSinCont)); - CUTE_STATIC_ASSERT_V(size<0>(tRrCosCont) == size<0>(tRrSinCont)); - CUTE_STATIC_ASSERT_V(size<0>(tQcQ) == size<0>(tRrCosCont)); - static_assert(decltype(size<0>(tRrCosCont))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - - #pragma unroll - for (int m = 0; m < size<1>(tQcQ); ++m) { - int const row = get<0>(tQcQ(_0{}, m, _0{})); - if (row < std::min(max_seqlen * qhead_per_khead - m_block * kBlockMN, kBlockMN)) { - #pragma unroll - for (int k = 0; k < size<2>(tQcQ); ++k) { - int const col = get<1>(tQcQ(_0{}, _0{}, k)); - if (col < rotary_dim / 2) { - int const col_idx_left = col / kGmemElemsPerLoad; - int const col_idx_right = col / kGmemElemsPerLoad + rotary_dim / (2 * kGmemElemsPerLoad); - Tensor rQ_left = make_fragment_like(sQ_copy(_, row, col_idx_left)); - cute::copy(tiled_copy_q, sQ_copy(_, row, col_idx_left), rQ_left); - Tensor rQ_right = make_fragment_like(rQ_left); - cute::copy(tiled_copy_q, sQ_copy(_, row, col_idx_right), rQ_right); - apply_rotary_contiguous(rQ_left, rQ_right, tRrCosCont(_, m, k), tRrSinCont(_, m, k)); - cute::copy(tiled_copy_q, rQ_left, sQ_copy(_, row, col_idx_left)); - cute::copy(tiled_copy_q, rQ_right, sQ_copy(_, row, col_idx_right)); - } - } - } - } - }; - - template - CUTLASS_DEVICE - void - apply_K_interleaved(TensorsK const &sK, // (kBlockN, kHeadDim) - TensorgK &gK, // (kBlockN, kHeadDim) - TensorpK const &tKpK, // (kBlockN, kHeadDim) split according to ThrCopyKV - TensortRrR const &tRrCos, // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotary - TensortRrR const &tRrSin, // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotary - TensorKPtr const &tPrKPtr, - int const n_block) - { - TiledCopyQK tiled_copy_k; - auto gmem_thr_copy_q = tiled_copy_k.get_thread_slice(thread_idx); - Tensor tKsK = gmem_thr_copy_q.partition_S(sK); - Tensor tKgK = gmem_thr_copy_q.partition_S(gK); - Tensor tKcK = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape, Int>{})); - - CUTE_STATIC_ASSERT_V(rank(tKsK) == _3{}); - CUTE_STATIC_ASSERT_V(rank(tRrCos) == _3{}); - CUTE_STATIC_ASSERT_V(rank(tRrSin) == _3{}); - CUTE_STATIC_ASSERT_V(size<1>(tKsK) == size<1>(tRrCos)); - CUTE_STATIC_ASSERT_V(size<2>(tKsK) == size<2>(tRrCos)); - CUTE_STATIC_ASSERT_V(size<1>(tKsK) == size<1>(tRrSin)); - CUTE_STATIC_ASSERT_V(size<2>(tKsK) == size<2>(tRrSin)); - CUTE_STATIC_ASSERT_V(size<0>(tRrCos) == size<0>(tRrSin)); - static_assert(decltype(size<0>(tKsK))::value == decltype(size<0>(tRrCos))::value * 2); - static_assert(decltype(size<0>(tRrCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - if constexpr (PagedKVNonTMA) { - static_assert(decltype(size(tPrKPtr))::value == cute::ceil_div(size<1>(tKcK), kGmemThreadsPerRow)); - } - - #pragma unroll - for (int m = 0; m < size<1>(tKsK); ++m) { - int const row = get<0>(tKcK(_0{}, m, _0{})); - auto mK_cur_copy = [&] { - if constexpr (PagedKVNonTMA) { - Element* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); - Tensor mK_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); - return cute::tiled_divide(mK_cur, Shape>{}); - } else { - return nullptr; - } - }(); - if (row < std::min(max_seqlen - n_block * kBlockMN, kBlockMN)) { - #pragma unroll - for (int k = 0; k < size<2>(tKsK); ++k) { - if (tKpK(k)) { - Tensor rK = make_fragment_like(tKsK(_, m, k)); - cute::copy(tiled_copy_k, tKsK(_, m, k), rK); - if (tRpR(k)) { apply_rotary_interleaved(rK, tRrCos(_, m, k), tRrSin(_, m, k)); } - if constexpr (!PagedKVNonTMA) { - cute::copy(tiled_copy_k, rK, tKgK(_, m, k)); - } else { - int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; - cute::copy(tiled_copy_k, rK, mK_cur_copy(_, ki)); - } - } - } - } - } - }; - - template - CUTLASS_DEVICE - void - apply_K_contiguous(TensorsK const &sK, // (kBlockN, kHeadDim) - TensorgK &gK, // (kBlockN, kHeadDim) - TensorpK const &tKpK, // (kBlockN, kHeadDim) split according to ThrCopyKV - TensortRrR const &tRrCosCont, // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotaryCont - TensortRrR const &tRrSinCont, // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotaryCont - TensorKPtr const &tPrKPtr, - int const n_block, int const max_k) - { - TiledCopyQK tiled_copy_k; - auto gmem_thr_copy_q = tiled_copy_k.get_thread_slice(thread_idx); - Tensor sK_copy = cute::tiled_divide(sK, Shape<_1, Int>{}); - Tensor gK_copy = cute::tiled_divide(gK, Shape<_1, Int>{}); - Tensor tKcK = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape, Int>{})); - - CUTE_STATIC_ASSERT_V(rank(tKcK) == _3{}); - CUTE_STATIC_ASSERT_V(rank(tRrCosCont) == _3{}); - CUTE_STATIC_ASSERT_V(rank(tRrSinCont) == _3{}); - CUTE_STATIC_ASSERT_V(size<1>(tKcK) == size<1>(tRrCosCont)); - CUTE_STATIC_ASSERT_V(size<2>(tKcK) == size<2>(tRrCosCont)); - CUTE_STATIC_ASSERT_V(size<1>(tKcK) == size<1>(tRrSinCont)); - CUTE_STATIC_ASSERT_V(size<2>(tKcK) == size<2>(tRrSinCont)); - CUTE_STATIC_ASSERT_V(size<0>(tRrCosCont) == size<0>(tRrSinCont)); - CUTE_STATIC_ASSERT_V(size<0>(tKcK) == size<0>(tRrCosCont)); - static_assert(decltype(size<0>(tRrCosCont))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - if constexpr (PagedKVNonTMA) { - static_assert(decltype(size(tPrKPtr))::value == cute::ceil_div(size<1>(tKcK), kGmemThreadsPerRow)); - } - - const int ro_dim_vec = rotary_dim / kGmemElemsPerLoad; - const int non_ro_dim_vec = (max_k - rotary_dim) / kGmemElemsPerLoad; - #pragma unroll - for (int m = 0; m < size<1>(tKcK); ++m) { - int const row = get<0>(tKcK(_0{}, m, _0{})); - Tensor gK_cur_copy = [&] { - if constexpr (PagedKVNonTMA) { - Element* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); - Tensor mK_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); - return cute::tiled_divide(mK_cur, Shape>{}); - } else { - return gK_copy(_, row, _); - } - }(); - if (row < std::min(max_seqlen - n_block * kBlockMN, kBlockMN)) { - #pragma unroll - for (int k = 0; k < size<2>(tKcK); ++k) { - if (tKpK(k)) { - int const col = get<1>(tKcK(_0{}, _0{}, k)); - bool rotate = col < rotary_dim / 2; - int const col_idx_left = rotate ? col / kGmemElemsPerLoad : (col + rotary_dim / 2) / kGmemElemsPerLoad; - int const col_idx_right = col_idx_left + (rotate ? ro_dim_vec / 2 : non_ro_dim_vec / 2); - Tensor rK_left = make_fragment_like(sK_copy(_, row, col_idx_left)); - cute::copy(tiled_copy_k, sK_copy(_, row, col_idx_left), rK_left); - Tensor rK_right = make_fragment_like(rK_left); - cute::copy(tiled_copy_k, sK_copy(_, row, col_idx_right), rK_right); - if (rotate) { - apply_rotary_contiguous(rK_left, rK_right, tRrCosCont(_, m, k), tRrSinCont(_, m, k)); - } - cute::copy(tiled_copy_k, rK_left, gK_cur_copy(_, col_idx_left)); - if (col_idx_right * kGmemElemsPerLoad < max_k) { - cute::copy(tiled_copy_k, rK_right, gK_cur_copy(_, col_idx_right)); - } - } - } - } - } - }; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace flash diff --git a/flash-attn/seqlen.h b/flash-attn/seqlen.h deleted file mode 100644 index 5547238b34892248ac2e2b2c464556136e304468..0000000000000000000000000000000000000000 --- a/flash-attn/seqlen.h +++ /dev/null @@ -1,95 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -namespace flash { - -// We consolidate all the info related to sequence length here. This is so that we can do all -// the gmem reads once at the beginning of each tile, rather than having to repeat these reads -// to compute various things like n_block_min, n_block_max, etc. - -template -struct SeqlenInfo { - - int const offset, offset_padded; - int const seqlen; - - CUTLASS_DEVICE - SeqlenInfo(int const bidb, int const seqlen_static, int const* const cu_seqlens, int const* const seqused) - : offset(!Varlen || cu_seqlens == nullptr ? 0 : cu_seqlens[bidb]) - , offset_padded(!Varlen || cu_seqlens == nullptr ? 0 : (cu_seqlens[bidb] + bidb * kBlock) / kBlock * kBlock) - , seqlen(!Varlen - ? seqlen_static - : (seqused ? seqused[bidb] : (cu_seqlens ? cu_seqlens[bidb + 1] - cu_seqlens[bidb] : seqlen_static))) - { - } - -}; - -template -struct SeqlenInfoQK { - - int const offset_q, offset_k, offset_q_padded; - int const seqlen_q, seqlen_k; - - CUTLASS_DEVICE - SeqlenInfoQK(int const bidb, int const seqlen_q_static, int const seqlen_k_static, - int const* const cu_seqlens_q, int const* const cu_seqlens_k, - int const* const seqused_q, int const* const seqused_k - ) - : offset_q(!Varlen || cu_seqlens_q == nullptr ? 0 : cu_seqlens_q[bidb]) - , offset_k(!Varlen || cu_seqlens_k == nullptr ? 0 : cu_seqlens_k[bidb]) - // If varlen, the layout for dPSum, LSE_log2, and dQaccum is that we pad each sequence in the batch - // by an extra kBlockM, so that the write for each sequence doesn't touch the next sequence. - // Sequence i starts at cu_seqlens[i] + i * kBlockM and ends at cu_seqlens[i + 1] + i * kBlockM - // However, the start must align to multiples of kBlockM. - , offset_q_padded(!Varlen || cu_seqlens_q == nullptr ? 0 : (cu_seqlens_q[bidb] + bidb * kBlockM) / kBlockM * kBlockM) - , seqlen_q(!Varlen - ? seqlen_q_static - : (seqused_q ? seqused_q[bidb] : (cu_seqlens_q ? cu_seqlens_q[bidb + 1] - cu_seqlens_q[bidb] : seqlen_q_static))) - , seqlen_k(!Varlen - ? seqlen_k_static - : (seqused_k ? seqused_k[bidb] : (cu_seqlens_k ? cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb] : seqlen_k_static))) - { - } - -}; - -template -struct SeqlenInfoQKNewK { - - static_assert(!(AppendKV && !Varlen), "AppendKV is only supported with Varlen"); - - int const leftpad_k; - int const offset_q, offset_k, offset_k_new; - int const seqlen_q, seqlen_k_og, seqlen_k_new, seqlen_k, seqlen_rotary; - - CUTLASS_DEVICE - SeqlenInfoQKNewK(int const bidb, int const seqlen_q_static, int const seqlen_k_static, int const shape_K_new_0, - int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new, - int const* const seqused_q, int const* const seqused_k, int const* const ptr_leftpad_k, - int const* const seqlens_rotary - ) - : leftpad_k(ptr_leftpad_k ? ptr_leftpad_k[bidb] : 0) - , offset_q(!Varlen || cu_seqlens_q == nullptr ? 0 : cu_seqlens_q[bidb]) - , offset_k(!Varlen ? 0 : (cu_seqlens_k ? cu_seqlens_k[bidb] : 0) + leftpad_k) - , offset_k_new(!AppendKV || cu_seqlens_k_new == nullptr ? 0 : cu_seqlens_k_new[bidb]) - , seqlen_q(!Varlen - ? seqlen_q_static - : (seqused_q ? seqused_q[bidb] : (cu_seqlens_q ? cu_seqlens_q[bidb + 1] - cu_seqlens_q[bidb] : seqlen_q_static))) - , seqlen_k_og(!Varlen - ? seqlen_k_static - : (seqused_k ? seqused_k[bidb] : (cu_seqlens_k ? cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb] : seqlen_k_static)) - leftpad_k) - , seqlen_k_new(!AppendKV - ? 0 - : (cu_seqlens_k_new ? cu_seqlens_k_new[bidb + 1] - cu_seqlens_k_new[bidb] : shape_K_new_0)) - , seqlen_k(!AppendKV ? seqlen_k_og : seqlen_k_og + seqlen_k_new) - , seqlen_rotary(!AppendKV || !seqlens_rotary ? seqlen_k_og + leftpad_k : seqlens_rotary[bidb]) - { - } - -}; - -} // namespace flash diff --git a/flash-attn/sm90_pipeline_no_cluster.hpp b/flash-attn/sm90_pipeline_no_cluster.hpp deleted file mode 100644 index 65a3d1554b362260207f295630433a95b7e5ac39..0000000000000000000000000000000000000000 --- a/flash-attn/sm90_pipeline_no_cluster.hpp +++ /dev/null @@ -1,99 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -namespace cutlass { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// As of Cutlass v3.6.0, if size(ClusterShape) == 1, PipelineTmaAsync has all threads -// signaling the barrier during consumer_release. This causes a perf regression in FA3 -// forward pass (especially hdim 128 causal). We instead reimplement the version of -// PipelineTmaAsync before v3.6.0 where only 1 out of 128 threads signals the barrier. -// -// Assumption: params.num_consumers % NumThreadsPerWarpGroup == 0 -template > -class PipelineTmaAsyncNoCluster: public Base { -public: - using FullBarrier = typename Base::FullBarrier; - using EmptyBarrier = typename Base::EmptyBarrier; - static constexpr uint32_t Stages = Stages_; - using PipelineState = typename Base::PipelineState; - - using SharedStorage = typename Base::SharedStorage; - using ThreadCategory = typename Base::ThreadCategory; - using Params = typename Base::Params; - - static - CUTLASS_DEVICE - void - init_barriers(SharedStorage& storage, Params params) { - int warp_idx = canonical_warp_idx_sync(); - bool is_initializing_warp = (warp_idx == 0); - if (is_initializing_warp) { - // Barrier FULL and EMPTY init - constexpr int producer_arv_cnt = 1; - uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; - uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster; - - cutlass::arch::detail::initialize_barrier_array_pair_aligned( - storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); - } - cutlass::arch::fence_barrier_init(); - } - - template - CUTLASS_DEVICE - PipelineTmaAsyncNoCluster(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) - : Base(storage, params, make_shape(_1{}, _1{}, _1{}) /*cluster_shape*/, cute::false_type{} /*init_barriers*/, cute::false_type{} /*init_masks*/) - , empty_barrier_ptr_(&storage.empty_barrier_[0]) { - - int warp_idx = canonical_warp_idx_sync(); - int lane_predicate = cute::elect_one_sync(); - - static_assert(cute::is_same_v || cute::is_same_v); - static_assert(cute::is_same_v || cute::is_same_v); - if constexpr (cute::is_same_v) { - init_barriers(storage, params); - } - - } - - // Constructor - template - CUTLASS_DEVICE - PipelineTmaAsyncNoCluster(SharedStorage& storage, Params params, ClusterShape cluster_shape) - : PipelineTmaAsyncNoCluster(storage, params, cluster_shape, cute::true_type{}, cute::true_type{}) { } - - template - CUTLASS_DEVICE - PipelineTmaAsyncNoCluster(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}) - : PipelineTmaAsyncNoCluster(storage, params, cluster_shape, InitBarriers{}, cute::true_type{}) { } - - CUTLASS_DEVICE - void consumer_release(PipelineState state) { - consumer_release(state.index()); - } - -private: - EmptyBarrier* const empty_barrier_ptr_ = nullptr; - - // Consumer signalling Producer of completion - // Ensures all blocks in the Same Row and Column get notifed. - CUTLASS_DEVICE - void consumer_release(uint32_t stage, uint32_t skip = false) { - empty_barrier_ptr_[stage].arrive(0 /*dst_blockid_*/, uint32_t(threadIdx.x % cutlass::NumThreadsPerWarpGroup == 0) & (!skip) /*is_signaling_thread*/); - } - -}; - - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // end namespace cutlass diff --git a/flash-attn/softmax.h b/flash-attn/softmax.h deleted file mode 100644 index e167b2e4955979c561c87786c25f22a2ddfc5727..0000000000000000000000000000000000000000 --- a/flash-attn/softmax.h +++ /dev/null @@ -1,191 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -#include - -#include - -#include "utils.h" - -namespace flash { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); - #pragma unroll - for (int ni = 0; ni < size<1>(tensor); ni++) { - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); mi++) { - summary(mi) = zero_init && ni == 0 ? tensor(mi, ni) : op(summary(mi), tensor(mi, ni)); - } - } -} - -template -__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { - CUTE_STATIC_ASSERT_V(size(dst) == size(src)); - #pragma unroll - for (int i = 0; i < size(dst); i++) { - dst(i) = Allreduce<4>::run(src(i), op); - } -} - -template -__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { - thread_reduce_(tensor, summary, op); - quad_allreduce_(summary, summary, op); -} - -template -__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ - MaxOp max_op; - reduce_(tensor, max, max_op); -} - -template -__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ - SumOp sum_op; - thread_reduce_(tensor, sum, sum_op); - if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); } -} - -// Apply the exp to all the elements. -template -__forceinline__ __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { - // For FP8, we can subtract max by 8.0 so that the value after exp2 is in the range of [0, 256]. - // This lets us use more of the FP8 range (instead of just [0, 1]) to reduce underflow. - static constexpr float max_offset = float(Max_offset); // We can only template on int, not float - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - // If max is -inf, then all elements must have been -inf (possibly due to masking). - // We don't want (-inf - (-inf)) since that would give NaN. - const float max_scaled = Check_inf - ? (max(mi) == -INFINITY ? 0.f : (!Scale_max ? max(mi) : max(mi) * scale) - max_offset) - : (!Scale_max ? max(mi) : max(mi) * scale) - max_offset; - #pragma unroll - for (int ni = 0; ni < size<1>(tensor); ++ni) { - // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - // max * log_2(e)). This allows the compiler to use the ffma - // instruction instead of fadd and fmul separately. - tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Softmax { - - using TensorT = decltype(make_tensor(Shape>{})); - TensorT row_max, row_sum; - float const softmax_scale_log2; - - CUTLASS_DEVICE Softmax(float const softmax_scale_log2_) : softmax_scale_log2(softmax_scale_log2_) {}; - - template - __forceinline__ __device__ TensorT max_get_scale(Tensor0 &acc_s) { - // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - static_assert(CUTE_STATIC_V(size<0>(scores)) == kNRows); - TensorT scores_scale; - if constexpr (Is_first) { - flash::template reduce_max(scores, row_max); - cute::fill(scores_scale, 1.f); - } else { - Tensor scores_max_prev = make_fragment_like(row_max); - cute::copy(row_max, scores_max_prev); - flash::template reduce_max(scores, row_max); - #pragma unroll - for (int mi = 0; mi < size(row_max); ++mi) { - float scores_max_cur = !Check_inf - ? row_max(mi) - : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); - scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); - row_sum(mi) *= scores_scale(mi); - } - } - return scores_scale; - }; - - template - __forceinline__ __device__ void online_softmax(Tensor0 &acc_s) { - // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - static_assert(CUTE_STATIC_V(size<0>(scores)) == kNRows); - flash::template scale_apply_exp2(scores, row_max, softmax_scale_log2); - // We don't do the reduce across threads here since we don't need to use the row_sum. - // We do that reduce at the end when we need to normalize the softmax. - flash::reduce_sum(scores, row_sum); - }; - - __forceinline__ __device__ TensorT finalize(float const final_scale=1.f) { - SumOp sum_op; - quad_allreduce_(row_sum, row_sum, sum_op); - TensorT scores_scale; - #pragma unroll - for (int mi = 0; mi < size(row_sum); ++mi) { - float sum = row_sum(mi); - float inv_sum = (sum == 0.f || sum != sum) ? 0.f : 1.f / sum; - scores_scale(mi) = inv_sum * final_scale; - // For FP8, we might have scaled the output of exp by 2**8 so we need to divide sum by that amount. - if constexpr (Max_offset != 0) { - static constexpr float sum_scale = 1.f / float(1 << Max_offset); - sum *= sum_scale; - } - row_sum(mi) = (sum == 0.f || sum != sum) ? -INFINITY : row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum); - } - return scores_scale; - }; - - __forceinline__ __device__ TensorT finalize_aux(TensorT const& tSrSAux, float const final_scale=1.f) { - SumOp sum_op; - quad_allreduce_(row_sum, row_sum, sum_op); - TensorT scores_scale; - #pragma unroll - for (int mi = 0; mi < size(row_sum); ++mi) { - if (row_max(mi) == -INFINITY) { row_max(mi) = 0.f; } - const float max_scaled = row_max(mi) * softmax_scale_log2 - Max_offset; - float sum = row_sum(mi) + exp2f(float(M_LOG2E) * tSrSAux(mi) - max_scaled); - float inv_sum = (sum == 0.f || sum != sum) ? 0.f : 1.f / sum; - scores_scale(mi) = inv_sum * final_scale; - // For FP8, we might have scaled the output of exp by 2**8 so we need to divide sum by that amount. - if constexpr (Max_offset != 0) { - static constexpr float sum_scale = 1.f / float(1 << Max_offset); - sum *= sum_scale; - } - row_sum(mi) = (sum == 0.f || sum != sum) ? -INFINITY : row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum); - } - return scores_scale; - }; - - template - __forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) { - // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) - Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); - static_assert(CUTE_STATIC_V(size<0>(acc_o_rowcol)) == kNRows); - #pragma unroll - for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { - #pragma unroll - for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale(mi); } - } - }; - -}; - -} // namespace flash diff --git a/flash-attn/static_switch.h b/flash-attn/static_switch.h deleted file mode 100644 index 4701fa202ea68246fe3e19ea3a32dcbced624c00..0000000000000000000000000000000000000000 --- a/flash-attn/static_switch.h +++ /dev/null @@ -1,192 +0,0 @@ -// Inspired by -// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h -// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h - -#pragma once - -// For TORCH_CHECK -#include - -/// @param COND - a boolean expression to switch by -/// @param CONST_NAME - a name given for the constexpr bool variable. -/// @param ... - code to execute for true and false -/// -/// Usage: -/// ``` -/// BOOL_SWITCH(flag, BoolConst, [&] { -/// some_function(...); -/// }); -/// ``` -// - -#define BOOL_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - if (COND) { \ - constexpr static bool CONST_NAME = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr static bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - } \ - }() - -#ifdef FLASHATTENTION_DISABLE_LOCAL - #define CAUSAL_LOCAL_SWITCH(CAUSAL_COND, LOCAL_COND, CAUSAL_CONST_NAME, LOCAL_CONST_NAME, ...) \ - [&] { \ - constexpr static bool LOCAL_CONST_NAME = false; \ - if (CAUSAL_COND) { \ - constexpr static bool CAUSAL_CONST_NAME = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr static bool CAUSAL_CONST_NAME = false; \ - return __VA_ARGS__(); \ - } \ - }() -#else - #define CAUSAL_LOCAL_SWITCH(CAUSAL_COND, LOCAL_COND, CAUSAL_CONST_NAME, LOCAL_CONST_NAME, ...) \ - [&] { \ - if (CAUSAL_COND) { \ - constexpr static bool CAUSAL_CONST_NAME = true; \ - constexpr static bool LOCAL_CONST_NAME = false; \ - return __VA_ARGS__(); \ - } else if (LOCAL_COND) { \ - constexpr static bool CAUSAL_CONST_NAME = false; \ - constexpr static bool LOCAL_CONST_NAME = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr static bool CAUSAL_CONST_NAME = false; \ - constexpr static bool LOCAL_CONST_NAME = false; \ - return __VA_ARGS__(); \ - } \ - }() -#endif - -#ifdef FLASHATTENTION_DISABLE_SOFTCAP - #define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - constexpr static bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - }() -#else - #define SOFTCAP_SWITCH BOOL_SWITCH -#endif - -#ifdef FLASHATTENTION_DISABLE_PAGEDKV - #define PAGEDKV_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - constexpr static bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - }() -#else - #define PAGEDKV_SWITCH BOOL_SWITCH -#endif - -#ifdef FLASHATTENTION_DISABLE_SPLIT - #define SPLIT_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - constexpr static bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - }() -#else - #define SPLIT_SWITCH BOOL_SWITCH -#endif - -#ifdef FLASHATTENTION_DISABLE_APPENDKV - #define APPENDKV_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - constexpr static bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - }() -#else - #define APPENDKV_SWITCH BOOL_SWITCH -#endif - -#ifdef FLASHATTENTION_DISABLE_PACKGQA - #define PACKGQA_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - constexpr static bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - }() -#else - #define PACKGQA_SWITCH BOOL_SWITCH -#endif - -#ifdef FLASHATTENTION_DISABLE_VARLEN - #define VARLEN_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - constexpr static bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - }() -#elif defined(FLASHATTENTION_VARLEN_ONLY) - #define VARLEN_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - TORCH_CHECK(COND, "This flash attention build only supports varlen " \ - "(for build size reasons)."); \ - constexpr static bool CONST_NAME = true; \ - return __VA_ARGS__(); \ - }() -#else - #define VARLEN_SWITCH BOOL_SWITCH -#endif - -#ifdef FLASHATTENTION_DISABLE_CLUSTER - #define CLUSTER_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - constexpr static bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - }() -#else - #define CLUSTER_SWITCH BOOL_SWITCH -#endif - -#ifdef FLASHATTENTION_DISABLE_SM8x - #define ARCH_SWITCH(ARCH, ARCH_NAME, ...) \ - [&] { \ - constexpr static int ARCH_NAME = 90; \ - return __VA_ARGS__(); \ - }() -#else - #define ARCH_SWITCH(ARCH, ARCH_NAME, ...) \ - [&] { \ - if (ARCH == 86 || ARCH == 89) { \ - constexpr static int ARCH_NAME = 86; \ - return __VA_ARGS__(); \ - } else if (ARCH < 90) { \ - constexpr static int ARCH_NAME = 80; \ - return __VA_ARGS__(); \ - } else { \ - constexpr static int ARCH_NAME = 90; \ - return __VA_ARGS__(); \ - } \ - }() -#endif - -#ifndef FLASHATTENTION_ENABLE_VCOLMAJOR - #define VCOLMAJOR_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - constexpr static bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - }() -#else - #define VCOLMAJOR_SWITCH BOOL_SWITCH -#endif - -#define HEADDIM_SWITCH(HEADDIM, ...) \ - [&] { \ - if (HEADDIM == 64) { \ - constexpr static int kHeadSize = 64; \ - return __VA_ARGS__(); \ - } else if (HEADDIM == 96) { \ - constexpr static int kHeadSize = 96; \ - return __VA_ARGS__(); \ - } else if (HEADDIM == 128) { \ - constexpr static int kHeadSize = 128; \ - return __VA_ARGS__(); \ - } else if (HEADDIM == 96) { \ - constexpr static int kHeadSize = 96; \ - return __VA_ARGS__(); \ - } else if (HEADDIM == 256) { \ - constexpr static int kHeadSize = 256; \ - return __VA_ARGS__(); \ - } \ - }() diff --git a/flash-attn/tile_scheduler.hpp b/flash-attn/tile_scheduler.hpp deleted file mode 100644 index 53651d5c8484da437c5556aaa052d227fc0974a1..0000000000000000000000000000000000000000 --- a/flash-attn/tile_scheduler.hpp +++ /dev/null @@ -1,599 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include "cutlass/fast_math.h" -#include "cutlass/arch/barrier.h" - -#include "named_barrier.hpp" -#include "utils.h" - -namespace flash { - -/////////////////////////////////////////////////////////////////////////////// - -// Host side kernel arguments -struct TileSchedulerArguments { - // num_head is num_head_q if not PackGQA, else num_head_k - int const num_blocks, num_head, num_batch, num_splits; - int const qhead_per_khead; - int const seqlen; // Only used if Varlen and cu_seqlens == nullptr and seqused == nullptr - int const seqlen_k, headdim, headdim_v, element_size; // Used to calculate L2 swizzling - int* const tile_count_semaphore = nullptr; - int const* const cu_seqlens = nullptr; - int const* const seqused = nullptr; - // int const* const num_m_blocks_ptr = nullptr; - int const* const num_splits_dynamic_ptr = nullptr; -}; - -/////////////////////////////////////////////////////////////////////////////// - -template -class SingleTileScheduler { - -public: - - using SharedStorage = int; - - // Device side kernel params - struct Params { - int const num_blocks, num_head, num_batch, num_splits; - int const qhead_per_khead; - int const seqlen; - cutlass::FastDivmod nsplits_divmod; - int const* const cu_seqlens; - int const* const seqused; - int const* const num_splits_dynamic_ptr = nullptr; - }; - - static Params - to_underlying_arguments(TileSchedulerArguments const& args) { - assert(!Split || !Varlen || args.num_splits_dynamic_ptr != nullptr); - assert(!Split || !Varlen || args.num_splits < (1 << 16)); // We use the top 16 bits to store num_splits - return {args.num_blocks, args.num_head, args.num_batch, !Split ? 1 : args.num_splits, - args.qhead_per_khead, args.seqlen, - cutlass::FastDivmod(!Split ? 1 : args.num_splits), - !Varlen ? nullptr : args.cu_seqlens, !Varlen ? nullptr : args.seqused, - args.num_splits_dynamic_ptr}; - } - - static dim3 - get_grid_shape(Params const& params, int num_sm) { - return {uint32_t(params.num_blocks), uint32_t((!Split ? 1 : params.num_splits) * params.num_head), uint32_t(params.num_batch)}; - } - - struct WorkTileInfo { - int block_idx = 0; - int bidh = 0; - int bidb = 0; - int split_idx = 0; - - CUTLASS_DEVICE - bool - is_valid(Params const& params) const { - return bidb >= 0; - } - - CUTLASS_DEVICE - cute::tuple - get_block_coord(Params const& params) const { - return {block_idx, bidh, bidb, !Split ? 0 : split_idx}; - } - - }; - - CUTLASS_DEVICE - SingleTileScheduler(SharedStorage* const smem_scheduler) { } - - template - CUTLASS_DEVICE - WorkTileInfo - get_initial_work(Params const& params) const { - WorkTileInfo work_info {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), 0}; - if constexpr (Split) { - int split_idx; - work_info.bidh = params.nsplits_divmod.divmod(split_idx, work_info.bidh); - work_info.split_idx = split_idx; - } - bool is_valid_tile = true; - if constexpr (Varlen) { - int seqlen = params.seqused - ? params.seqused[work_info.bidb] - : (params.cu_seqlens ? params.cu_seqlens[work_info.bidb + 1] - params.cu_seqlens[work_info.bidb] : params.seqlen); - if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } - is_valid_tile = work_info.block_idx * kBlock < seqlen; - } - if constexpr (Varlen && Split) { - int num_splits_dynamic = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[work_info.bidb] : params.num_splits; - is_valid_tile &= work_info.split_idx < num_splits_dynamic; - // Use the top 16 bits to store num_splits - work_info.split_idx |= (num_splits_dynamic << 16); - } - work_info.bidb = is_valid_tile ? work_info.bidb : -1; - return work_info; - } - - CUTLASS_DEVICE - void - init_consumer() const {} - - CUTLASS_DEVICE - void - prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} - - template - CUTLASS_DEVICE - WorkTileInfo - get_next_work(Params const& params, WorkTileInfo const& current_work) const { - return {0, 0, -1, 0}; - } - -}; - -/////////////////////////////////////////////////////////////////////////////// - -template -class StaticPersistentTileScheduler { - -public: - - using SharedStorage = int; - - // Device side kernel params - struct Params { - int total_blocks; - cutlass::FastDivmod m_block_divmod, head_divmod; - cutlass::FastDivmod nsplits_divmod; - }; - - static Params - to_underlying_arguments(TileSchedulerArguments const& args) { - return {args.num_blocks * args.num_head * args.num_batch * (!Split ? 1 : args.num_splits), - cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head * (!Split ? 1 : args.num_splits)), - cutlass::FastDivmod(!Split ? 1 : args.num_splits)}; - } - - static dim3 - get_grid_shape(Params const& params, int num_sm) { - return {uint32_t(num_sm)}; - } - - struct WorkTileInfo { - int tile_idx; - - CUTLASS_DEVICE - bool - is_valid(Params const& params) const { - return tile_idx < params.total_blocks; - } - - CUTLASS_DEVICE - cute::tuple - get_block_coord(Params const& params) const { - int block, bidh, bidb; - bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(block, tile_idx)); - int split_idx = 0; - if constexpr (Split) { - bidh = params.nsplits_divmod.divmod(split_idx, bidh); - } - return {block, bidh, bidb, split_idx}; - } - - }; - - CUTLASS_DEVICE - StaticPersistentTileScheduler(SharedStorage* const smem_scheduler) {}; - - template - CUTLASS_DEVICE - WorkTileInfo - get_initial_work(Params const& params) const { - return {int(blockIdx.x)}; - } - - CUTLASS_DEVICE - void - init_consumer() const {} - - CUTLASS_DEVICE - void - prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} - - template - CUTLASS_DEVICE - WorkTileInfo - get_next_work(Params const& params, WorkTileInfo const& current_work) const { - return {current_work.tile_idx + int(gridDim.x)}; - } - -}; - -template -class DynamicPersistentTileScheduler { - - // This scheduler targets the causal (or local) case where each tile takes different - // amount of time. We use longest-processing-time-first scheduling: - // the longest remaining tile is assigned to the first SM that's free. - // SM indicates they are free by incrementing a semaphore. - // However, we have to make sure K & V still fit into L2 cache, so we perform scheduling - // on "sections" of the head & batch dimension, each section consisting of e.g. 8 heads. - // This is the L2 swizzling part. The size of each section is precomputed based on the - // size of K & V and the L2 cache size. - - static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); - static constexpr int NumThreads = WarpSpecialized ? NumMmaThreads + NumProducerThreads : NumMmaThreads; - -public: - using SharedStorage = int; - -protected: - SharedStorage* const tile_count_smem; - -public: - - // Device side kernel params - struct Params { - int const total_blocks; - cutlass::FastDivmod const m_block_divmod, head_divmod; - cutlass::FastDivmod const l2_minor_divmod, l2_major_divmod; - cutlass::FastDivmod const l2_minor_residual_divmod; - int const num_hb_quotient; - int* const tile_count_semaphore; - }; - - static Params - to_underlying_arguments(TileSchedulerArguments const& args) { - int const size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size * 2; - int const size_l2 = 32 * 1024 * 1024; // 32 MB for K & V - // Swizzle is the size of each "section". Round swizzle to a power of 2 - // If not PackGQA already, the size of each section can increase by qhead_per_khead - // Need to be careful about the case where only one head will fit - int const swizzle = (size_l2 < size_one_kv_head ? 1 : (1 << cutlass::find_log2(size_l2 / size_one_kv_head))) * (PackGQA ? 1 : args.qhead_per_khead); - // If we're in the last section (called residual), we don't want to divide by - // swizzle. Instead we want to divide by the remainder. - int const num_hb_remainder = (args.num_head * args.num_batch) % swizzle; - int const num_split_blocks = args.num_blocks * (!Split ? 1 : args.num_splits); - // printf("num_split_blocks = %d, num_head = %d, num_batch = %d, swizzle = %d, PackGQA = %d, qhead_per_khead = %d, num_hb_remainder = %d\n", num_split_blocks, args.num_head, args.num_batch, swizzle, int(PackGQA), args.qhead_per_khead, num_hb_remainder); - assert(args.tile_count_semaphore != nullptr); - return {num_split_blocks * args.num_head * args.num_batch, - cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head), - cutlass::FastDivmod(swizzle), cutlass::FastDivmod(swizzle * num_split_blocks), - // don't divide by 0 - cutlass::FastDivmod(num_hb_remainder > 0 ? num_hb_remainder : 1), - (args.num_head * args.num_batch) / swizzle, - args.tile_count_semaphore}; - } - - static dim3 - get_grid_shape(Params const& params, int num_sm) { - return {uint32_t(num_sm)}; - } - - struct WorkTileInfo { - int tile_idx; - - CUTLASS_DEVICE - bool - is_valid(Params const& params) const { - return tile_idx < params.total_blocks; - } - - CUTLASS_DEVICE - cute::tuple - get_block_coord(Params const& params) const { - int block, bidh, bidb; - int l2_mod, bidhb, bidhb_residual; - bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx); - // If we're in the last section (called residual), we don't want to divide by - // swizzle. Instead we want to divide by the remainder. - if (bidhb < params.num_hb_quotient) { - block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod); - } else { - block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod); - } - bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual); - int split_idx = 0; - if constexpr (Split) { - split_idx = params.m_block_divmod.divmod(block, block); - } - // Longest-processing-time-first - block = params.m_block_divmod.divisor - 1 - block; - return {block, bidh, bidb, split_idx}; - } - - }; - - CUTLASS_DEVICE - DynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : tile_count_smem(smem_scheduler) {}; - - template - CUTLASS_DEVICE - WorkTileInfo - get_initial_work(Params const& params) const { - return {int(blockIdx.x)}; - } - - CUTLASS_DEVICE - void - init_consumer() const { - if (WarpSpecialized || cutlass::canonical_warp_idx_sync() > 0) { - flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty - } - } - - CUTLASS_DEVICE - void - prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { - if (threadIdx.x % NumProducerThreads == 0) { - current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x); - } - } - - template - CUTLASS_DEVICE - WorkTileInfo - get_next_work(Params const& params, WorkTileInfo const& current_work) const { - if constexpr (IsProducerWarp) { - // thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0 - int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/); - flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty - if (threadIdx.x % NumProducerThreads == 0) { - *tile_count_smem = current_work.tile_idx; - } - flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull - return {new_tile_idx}; - } else { - flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull - int tile_idx = *tile_count_smem; - flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty - return {tile_idx}; - } - } - -}; - -template -class VarlenDynamicPersistentTileScheduler { - - static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); - static constexpr int NumThreads = WarpSpecialized ? NumMmaThreads + NumProducerThreads : NumMmaThreads; - -public: - using SharedStorage = int4; - -protected: - SharedStorage* const work_info_smem; - -public: - - // Device side kernel params - struct Params { - int num_head, num_batch; - int const qhead_per_khead; - int const seqlen; - cutlass::FastDivmod head_divmod; - cutlass::FastDivmod nsplits_divmod; - int* const tile_count_semaphore; - int const* const cu_seqlens; - int const* const seqused; - // int* const num_m_blocks_ptr; - int const* const num_splits_dynamic_ptr; - }; - - static Params - to_underlying_arguments(TileSchedulerArguments const& args) { - // If Split, for the purpose of scheduling, we pretend that instead there are - // (args.num_splits * args.num_head) number of heads. - assert(args.tile_count_semaphore != nullptr); - assert(args.num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx - assert(!Split || args.num_splits < (1 << 8)); // We use the top 8 bits to store num_splits - return {args.num_head, args.num_batch, - args.qhead_per_khead, args.seqlen, - cutlass::FastDivmod(args.num_head), - cutlass::FastDivmod(!Split ? 1 : args.num_splits), - args.tile_count_semaphore, args.cu_seqlens, args.seqused, - // args.num_m_blocks_ptr, - args.num_splits_dynamic_ptr}; - } - - static dim3 - get_grid_shape(Params const& params, int num_sm) { - return {uint32_t(num_sm)}; - } - - struct WorkTileInfo { - int tile_idx, block, bidh, bidb; - - CUTLASS_DEVICE - bool - is_valid(Params const& params) const { - // if (blockIdx.x >= 0 && (threadIdx.x == 128 || threadIdx.x == 0)) { printf("blockIdx.x = %d, threadIdx.x = %d, checking valid, bidb = %d, params.num_batch = %d\n", blockIdx.x, threadIdx.x, bidb, params.num_batch); } - return bidb < params.num_batch; - } - - CUTLASS_DEVICE - cute::tuple - get_block_coord(Params const& params) const { - if constexpr (!Split) { - return {block, bidh, bidb, 0 /*split_idx*/}; - } else { - // the top 8 bits of bidh store num_splits and the next 8 bits store split_idx - // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift - uint32_t bidh_packed = reinterpret_cast(bidh); - uint32_t bidh_actual_u = bidh_packed & 0x0000FFFF; - int bidh_actual = reinterpret_cast(bidh_actual_u); - // Use the top 16 bits of split_idx to store num_splits and the next 16 bits to store split_idx - uint32_t split_idx_u = ((bidh_packed & 0x00FF0000) >> 16) + ((bidh_packed & 0xFF000000) >> 8); - int split_idx = reinterpret_cast(split_idx_u); - // int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh); - // if (threadIdx.x == 128) { - // printf("blockIdx.x = %d, bidb = %d, bidh = %d, bidh_actual = %d, split_idx = %d\n", blockIdx.x, bidb, bidh, bidh_actual, split_idx); - // } - return {block, bidh_actual, bidb, split_idx}; - } - } - }; - - CUTLASS_DEVICE - VarlenDynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : work_info_smem(smem_scheduler) {}; - - CUTLASS_DEVICE - WorkTileInfo - tile_idx_to_work_tile(Params const& params, int next_tile_idx, WorkTileInfo const& current_work) const { - int lane = threadIdx.x % cutlass::NumThreadsPerWarp; - auto get_num_m_blocks = [&] (int bidb_start) { - int batch_idx = lane + bidb_start; - int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead); - if (seqlen > kBlock) { - if (params.seqused) { - seqlen = batch_idx < params.num_batch ? params.seqused[batch_idx] : 0; - } else if (params.cu_seqlens) { - int cur_cu_seqlen = batch_idx <= params.num_batch ? params.cu_seqlens[batch_idx] : 0; - int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); - seqlen = next_cu_seqlen - cur_cu_seqlen; - } else { - seqlen = params.seqlen; - } - if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } - } - return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 - ? cute::ceil_div(seqlen, kBlock) : 0; - // ? params.num_m_blocks_ptr[batch_idx] : 0; - }; - - auto get_num_splits = [&] (int bidb_start) { - int batch_idx = lane + bidb_start; - return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 - ? (!Split ? 1 : (params.num_splits_dynamic_ptr - ? params.num_splits_dynamic_ptr[batch_idx] - : params.nsplits_divmod.divisor)) - : 0; - }; - - int num_m_blocks = get_num_m_blocks(current_work.bidb); // Different for each lane - int num_splits = get_num_splits(current_work.bidb); - int num_split_m_blocks = !Split ? num_m_blocks : num_m_blocks * num_splits; - // Cumulative number of blocks for the next 31 batches - int num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks); - // Total number of blocks for the next 31 batches - int m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); - // Only the lower 16 bits are the actual bidh - int current_bidh = !Split ? current_work.bidh : (current_work.bidh & 0x0000FFFF); - int group_end_tile = current_work.tile_idx - current_work.block - current_bidh * __shfl_sync(0xffffffff, num_split_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes - if constexpr (Split) { - int current_split_idx = (current_work.bidh & 0x00FF0000) >> 16; - group_end_tile -= current_split_idx * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/); - } - int bidb = current_work.bidb; - // if (blockIdx.x <= 9 && threadIdx.x == 0) { - // printf("Before while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, cur tile_idx = %d, cur block = %d, cur bidh = %d, num_split_m_blocks = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, current_work.bidb, num_m_blocks, next_tile_idx, current_work.tile_idx, current_work.block, current_bidh, num_split_m_blocks, group_end_tile, m_blocks_in_group); - // } - // if (threadIdx.x == 0 && blockIdx.x == 0) { printf("tile_idx = %d, group_end_tile = %d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d\n", current_work.tile_idx, group_end_tile, num_m_blocks_cumulative, m_blocks_in_group); } - while (group_end_tile <= next_tile_idx) { - bidb += cutlass::NumThreadsPerWarp - 1; - if (bidb >= params.num_batch) { - // if (blockIdx.x <= 9 && threadIdx.x == 0) { - // printf("Returning early, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group); - // } - return {next_tile_idx, 0, 0, params.num_batch}; - } - num_m_blocks = get_num_m_blocks(bidb); - num_splits = get_num_splits(bidb); - num_split_m_blocks = !Split ? num_m_blocks : num_m_blocks * num_splits; - num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks); - m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); - group_end_tile += m_blocks_in_group * params.num_head; - // if (blockIdx.x <= 9 && threadIdx.x == 0) { - // printf("Bottom of while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group); - // } - } - int group_start_tile = group_end_tile - m_blocks_in_group * params.num_head; - // The next problem to process is the first one that does not have ending tile position - // that is greater than or equal to tile index. - int batch_idx_in_group = __popc(__ballot_sync(0xffffffff, group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx)); - // if (threadIdx.x == 31 || threadIdx.x == 0) { printf("blockIdx.x = %d, tidx %d, group_start_tile = %d, num_m_blocks_cumulative = %d, num_head = %d, next_tile_idx = %d, ballot = %x, batch_idx_in_group = %d\n", blockIdx.x, threadIdx.x, group_start_tile, num_m_blocks_cumulative, params.num_head, next_tile_idx, tmp, batch_idx_in_group); } - bidb += batch_idx_in_group; - num_m_blocks = __shfl_sync(0xffffffff, num_m_blocks, batch_idx_in_group); - if constexpr (Split) { num_splits = __shfl_sync(0xffffffff, num_splits, batch_idx_in_group); } - int mh_block = next_tile_idx - group_start_tile - (batch_idx_in_group == 0 ? 0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head; - int bidh = mh_block / num_m_blocks; - int block = mh_block - bidh * num_m_blocks; - if constexpr (Split) { - int bidh_actual = bidh / num_splits; - int split_idx = bidh - bidh_actual * num_splits; - // TODO: idk why this gives wrong answer nondeterministically - // int bidh_actual, split_idx; - // split_idx = params.head_divmod.divmod(bidh_actual, bidh); - // Use the top 8 bits to store num_splits and the next 8 bits to store split_idx - // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift - uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); - // if (threadIdx.x == 0) { - // printf("blockIdx.x = %d, group_start_tiled = %d, bidb = %d, batch_idx_in_group = %d, mh_block = %d, num_m_blocks = %d, bidh = %d, bidh_actual = %d, split_idx = %d, num_splits = %d, bidh_packed = %d\n", blockIdx.x, group_start_tile, bidb, batch_idx_in_group, mh_block, num_m_blocks, bidh, bidh_actual, split_idx, num_splits, bidh_packed); - // } - bidh = reinterpret_cast(bidh_packed); - } - // if (blockIdx.x <= 9 && threadIdx.x == 0) { - // printf("Before returning, blockIdx.x = %d, threadIdx.x = %d, group_start_tile = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, group_start_tile, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); - // } - return {next_tile_idx, block, bidh, bidb}; - } - - template - CUTLASS_DEVICE - WorkTileInfo - get_initial_work(Params const& params) const { - if constexpr (IsProducerWarp) { - WorkTileInfo work_info = tile_idx_to_work_tile(params, int(blockIdx.x), {0, 0, 0, 0}); - if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { - *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb); - } - flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull - return work_info; - } else { - return get_next_work(params, {0, 0, 0, 0}); - } - } - - CUTLASS_DEVICE - void - init_consumer() const { - // Don't arrive at the TileCountSmemEmpty barrier here, because get_initial_work will do that - } - - CUTLASS_DEVICE - void - prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { - if (threadIdx.x % NumProducerThreads == 0) { - current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x); - } - } - - template - CUTLASS_DEVICE - WorkTileInfo - get_next_work(Params const& params, WorkTileInfo const& current_work) const { - if constexpr (IsProducerWarp) { - // thread 0 has the next tile_idx, just need to broadcast to the rest of warp 0 - int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/); - WorkTileInfo work_info = {__shfl_sync(0xffffffff, current_work.tile_idx, 1 /*lane*/), current_work.block, current_work.bidh, current_work.bidb}; - work_info = tile_idx_to_work_tile(params, new_tile_idx, work_info); - flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty - if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { - *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb); - } - flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull - return work_info; - } else { - flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull - int4 work_info = *work_info_smem; - flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty - return WorkTileInfo{work_info.x, work_info.y, work_info.z, work_info.w}; - } - } - -}; - -} // flash diff --git a/flash-attn/tile_size.h b/flash-attn/tile_size.h deleted file mode 100644 index b87a83afff8de3c350cc5b3f63c0c53329496848..0000000000000000000000000000000000000000 --- a/flash-attn/tile_size.h +++ /dev/null @@ -1,82 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -// Return {kBlockM, kBlockN, MmaPV_is_RS, IntraWGOverlap} -constexpr std::tuple tile_size_fwd_sm90( - int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2, - bool v_colmajor=false, bool paged_kv_non_TMA=false, bool softcap=false, bool use_one_mma_wg=false) { - if (element_size == 2) { - if (headdim <= 64) { - // return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, same_hdim}; - // With this workaround in Cutlass 3.8, tile size 192 x 128 got slower for non-causal, idk why - // https://github.com/NVIDIA/cutlass/blob/833f6990e031b48b4cd2fcf55e0849c51ef6bac2/include/cute/container/tuple.hpp#L131 - if (headdim_v == 512) { - return {64, 64, false, false}; - } else if (headdim_v == 256) { - return {128, 112, true, false}; - } else { - // Switch to tile size 192 x 192 for now - bool const use_blockN_128 = is_causal || is_local; - return {192, use_blockN_128 ? 128 : 192, use_blockN_128, true}; - } - // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen - // return {192, is_causal || is_local ? 192 : 176, true, false}; - } else if (headdim <= 96) { - return {192, is_local || paged_kv_non_TMA ? 128 : 144, false, true}; - } else if (headdim <= 128) { - if (use_one_mma_wg) { - return {64, is_causal || is_local || paged_kv_non_TMA ? 128 : 176, true, true}; - } else { - return {128, is_causal || is_local || paged_kv_non_TMA ? 128 : 176, true, true}; - } - // {128, 192, false, false} and {192, 128, false, true} are quite good too - // 128 x 192 hits the limit of smem if MmaPV_is_RS, 128 x 144 hits the limit if !MmaPV_is_RS - } else if (headdim <= 192) { - return {128, paged_kv_non_TMA || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true}; // 128 x 112 hits the limit of smem - } else { - return {128, is_local ? 64 : 80, true, true}; // 128 x 80 hits the limit of smem - } - } else { - if (headdim <= 64) { - return {192, 160, true, true}; - } else if (headdim <= 96) { - return {192, 128, true, true}; - } else if (headdim <= 128) { - return {128, paged_kv_non_TMA ? 160 : (v_colmajor || (softcap && is_local) ? 192 : 224), true, true}; - } else if (headdim <= 192) { - return {128, (paged_kv_non_TMA || softcap) && is_local ? 128 : 160, true, true}; - } else { - return {128, is_local ? 64 : 128, true, !paged_kv_non_TMA}; // PagedKV uses more registers so we disabled IntraWGOverlap - } - } -} - -// Return {kBlockM, kBlockN, kNWarps, kStages, Q_in_regs} -constexpr std::tuple tile_size_fwd_sm8x( - bool sm86_or_89, int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2, - bool paged_kv=false, bool varlen_and_split=false, - bool softcap=false, bool append_kv=false) { - if (element_size == 2) { - if (headdim <= 64) { - return {128, varlen_and_split ? 80 : (is_local ? 96 : 112), 4, 1, false}; - } else if (headdim <= 96) { - return {128, varlen_and_split || is_local ? 48 : 64, 4, 1, false}; - } else if (headdim <= 128) { - bool const use_8_warps = sm86_or_89 | varlen_and_split; - return {128, use_8_warps ? (varlen_and_split ? (is_local ? 96 : 112) : (is_local ? 96 : 128)) : (is_local ? 48 : 64), use_8_warps ? 8 : 4, 1, use_8_warps}; - } else if (headdim <= 192) { - bool const kBlockN_64 = append_kv || is_local || varlen_and_split || paged_kv; - return {128, kBlockN_64 ? 64 : 96, 8, sm86_or_89 ? 1 : 2, !kBlockN_64}; - } else { - return {128, sm86_or_89 ? (append_kv ? 32 : (varlen_and_split || is_local ? 48 : 64)) : (append_kv ? 48 : (varlen_and_split || is_local ? 64 : 96)), 8, 1, sm86_or_89 && !append_kv}; - } - } else { - // Placeholder for now - return {128, 64, 8, 2, false}; - } -} diff --git a/flash-attn/utils.h b/flash-attn/utils.h deleted file mode 100644 index 3719eab9202971c679f66865a1629d8e2e553fa8..0000000000000000000000000000000000000000 --- a/flash-attn/utils.h +++ /dev/null @@ -1,679 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include -#include -#include - -#include - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#include -#endif - -#include - -#include -#include -#include -#include - -#include "cuda_check.h" - -namespace flash { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// A wrapper for the kernel that is used to guard against compilation on -// architectures that will never use the kernel. The purpose of this is to -// reduce the size of the compiled binary. -// Adapted from https://github.com/vllm-project/vllm/blob/4d29e91be84d27ca313d657eee92c067439a4c23/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh#L55 -template -struct enable_sm90_or_later : Kernel { - template - CUTLASS_DEVICE void operator()(Args&&... args) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - Kernel::operator()(std::forward(args)...); -#endif - } -}; - -template -struct enable_sm80_to_sm89 : Kernel { - template - CUTLASS_DEVICE void operator()(Args&&... args) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ <= 890) - Kernel::operator()(std::forward(args)...); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MaxOp { -__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; } -}; - -template <> -struct MaxOp { -// This is slightly faster -__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct SumOp { -__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Allreduce { - static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); - template - static __device__ __forceinline__ T run(T x, Operator &op) { - constexpr int OFFSET = THREADS / 2; - x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); - return Allreduce::run(x, op); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Allreduce<2> { -template -static __device__ __forceinline__ T run(T x, Operator &op) { - x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); - return x; -} -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) -// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) -template -CUTLASS_DEVICE auto convert_layout_acc_rowcol(Layout0 acc_layout) { - if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 - static_assert(decltype(size<0, 0>(acc_layout))::value == 2); - static_assert(decltype(size<0, 1>(acc_layout))::value == 2); - static_assert(decltype(rank(acc_layout))::value == 3); - auto l = acc_layout; - if constexpr (!Transposed) { - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l))); - } else { - return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); - } - - } else { // SM80 - static_assert(decltype(size<0>(acc_layout))::value == 4); - static_assert(decltype(rank(acc_layout))::value == 3); - auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) - if constexpr (!Transposed) { - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); - } else { - return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) -// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. -// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) -// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N)) -template -CUTLASS_DEVICE auto convert_layout_acc_Aregs(Layout0 acc_layout) { - using X = Underscore; - if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 - static_assert(decltype(size<0, 0>(acc_layout))::value == 2); - static_assert(decltype(size<0, 1>(acc_layout))::value == 2); - static_assert(decltype(rank(acc_layout))::value == 3); - static_assert(decltype(rank(get<0>(acc_layout)))::value == 3); - if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) { - auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{}); // ((2, N / 16)) - return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); - } else { - static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1); - static_assert(decltype(stride<0, 0>(acc_layout))::value == 1); - static_assert(decltype(stride<0, 1>(acc_layout))::value == 2); - auto l = logical_divide(get<0, 2>(acc_layout), Tile>>{}); // (((2, 2), N / 32)) - // This combines the first two modes (<0, 0> and <0, 1>) into one mode. - // Will require register shuffling later to be correct. - return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)), - get<1>(acc_layout), - coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); // ((4, 2, 2), MMA_M, N / 32 * MMA_N) - // This combination is right but doesn't work with register shuffling. - // return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)), - // get<1>(acc_layout), - // coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); - } - } else { // SM80 - static_assert(decltype(size<0>(acc_layout))::value == 4); - static_assert(decltype(rank(acc_layout))::value == 3); - constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{}); - static_assert(mma_shape_K == 8 || mma_shape_K == 16); - if constexpr (mma_shape_K == 8) { - return acc_layout; - } else { - auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) - return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTLASS_DEVICE auto convert_type_unsafe(Tensor const &tensor) { - using From_type = typename Engine::value_type; - static constexpr int numel = decltype(size(tensor))::value; - cutlass::NumericArrayConverter convert_op; - // HACK: this requires tensor to be "contiguous" - auto frag = convert_op(*reinterpret_cast *>(tensor.data())); - return make_tensor(make_rmem_ptr(&frag), tensor.layout()); - // Unsafe because we're returning a tensor with memory allocated on the stack. If the compiler does not - // inline this function, then the memory might not be valid. -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTLASS_DEVICE void convert_type_out(Tensor const &tensor, Tensor &out) { - // Somehow if we allocate out inside this function and return it, e2e is slower and the output can be wrong. - using From_type = typename Engine::value_type; - using To_type = typename EngineOut::value_type; - static constexpr int FragmentSize = std::max(sizeof(From_type) / sizeof(To_type), sizeof(To_type) / sizeof(From_type)); - static_assert(CUTE_STATIC_V(size(tensor)) % FragmentSize == 0, "Fragment size does not vectorize properly"); - Tensor frag = recast const>(tensor); - Tensor out_frg = recast>(out); - static_assert(size(frag) == size(out_frg)); - cutlass::NumericArrayConverter convert_op; - #pragma unroll - for (int i = 0; i < size(frag); ++i) { out_frg[i] = convert_op(frag[i]); } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Blocks until all but N previous cp.async.commit_group operations have committed. -// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all -// (which is equivalent to commit_group then wait_group 0). -// Instead we just call cp.async.wait_group 0, which is slightly faster. -// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 -template -CUTE_HOST_DEVICE -void cp_async_wait() { -#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) - asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTLASS_DEVICE -auto mma_partition_fragment_AB(Mma const& mma, Tensor0 const& tensor0) { - if constexpr (A) { - return mma.partition_fragment_A(tensor0); - } else { - return mma.partition_fragment_B(tensor0); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const& tCrB, Tensor2& tCrC) { - if constexpr (M_slice >= 0) { - static constexpr int MMA_M = decltype(size<1>(tCrC))::value; - static_assert(M_slice < MMA_M); - // After logical_divide, C has shape ((2,2,V), (MMA_M, 1), MMA_N) - Tensor tCrC_slice = cute::logical_divide(tCrC, Shape>{})(_, make_coord(Int{}, _), _); - if constexpr (!SwapAB) { - Tensor tCrA_slice = cute::logical_divide(tCrA, Shape>{})(_, make_coord(Int{}, _), _); - gemm(tiled_mma, tCrA_slice, tCrB, tCrC_slice); - } else { - Tensor tCrB_slice = cute::logical_divide(tCrB, Shape>{})(_, make_coord(Int{}, _), _); - gemm(tiled_mma, tCrA, tCrB_slice, tCrC_slice); - } - } else { - constexpr bool Is_RS = !cute::is_base_of::value; - // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const - if constexpr (Is_RS) { - if constexpr (!SwapAB) { - warpgroup_fence_operand(const_cast(tCrA)); - } else { - warpgroup_fence_operand(const_cast(tCrB)); - } - } - warpgroup_fence_operand(tCrC); - warpgroup_arrive(); - if constexpr (zero_init) { - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - } - static constexpr int kNumKIters = CUTE_STATIC_V(size<2>(tCrA)); - static constexpr int kMaxKIters = 16; - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < std::min(kNumKIters, kMaxKIters); ++k_block) { - if constexpr (!SwapAB) { - cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); - } else { - cute::gemm(tiled_mma, tCrB(_,_,k_block), tCrA(_,_,k_block), tCrC); - } - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - // In the case of large kNumKIters, the compiler chooses to store the smem addresses - // in registers, causing spills. This loop forces the compiler to recompute the addresses. - if constexpr (kNumKIters > kMaxKIters) { - // This will always be zero, just a way to force the compiler to recompute the smem - // addresses. This results in USEL instructions. There's probably a better way to do this. - int const k_offset = cutlass::canonical_warp_group_idx() < 128 ? 0 : 1; - CUTLASS_PRAGMA_UNROLL - for (int k_block = kMaxKIters; k_block < kNumKIters; ++k_block) { - if constexpr (!SwapAB) { - cute::gemm(tiled_mma, tCrA(_,_,k_block + k_offset), tCrB(_,_,k_block + k_offset), tCrC); - } else { - cute::gemm(tiled_mma, tCrB(_,_,k_block + k_offset), tCrA(_,_,k_block + k_offset), tCrC); - } - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - } - warpgroup_commit_batch(); - if constexpr (wg_wait >= 0) { warpgroup_wait(); } - warpgroup_fence_operand(tCrC); - if constexpr (Is_RS) { - if constexpr (!SwapAB) { - warpgroup_fence_operand(const_cast(tCrA)); - } else { - warpgroup_fence_operand(const_cast(tCrB)); - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTLASS_DEVICE void gemm_sm80(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, - Tensor4 const& tCsB, TiledMma tiled_mma, - TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, - ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B, Hook fn) { - if constexpr (SwapAB) { - gemm_sm80(acc, tCrB, tCrA, tCsB, tCsA, tiled_mma, smem_tiled_copy_B, smem_tiled_copy_A, smem_thr_copy_B, smem_thr_copy_A, fn); - } else { - CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M - CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N - CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K - Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M - Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N - if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } - if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } - #pragma unroll - for (int i = 0; i < size<2>(tCrA); ++i) { - if (i < size<2>(tCrA) - 1) { - if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } - if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } - } - if constexpr (!std::is_same_v) { - if (i == 0) { fn(); } - } - cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTLASS_DEVICE void gemm_rs_sm80(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, - TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, - ThrCopy smem_thr_copy_B) { - CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M - CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N - CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K - Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N - cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); - #pragma unroll - for (int i = 0; i < size<2>(tCrA); ++i) { - if (i < size<2>(tCrA) - 1) { - cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); - } - cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTLASS_DEVICE void gemm_sm100(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { - static constexpr int rA = decltype(rank(tA))::value; - static constexpr int rB = decltype(rank(tB))::value; - static constexpr int rC = decltype(rank(tC))::value; - static_assert(rA == 3 && rB == 3 && rC == 3); - - if constexpr (zero_init) { atom.accumulate_ = decltype(atom.accumulate_)::Zero; } - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tA); k_block++) { - cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC); - atom.accumulate_ = decltype(atom.accumulate_)::One; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTE_HOST_DEVICE constexpr -auto -to_tiled_mma_sm100_ts( - TiledMMA, cute::C, - cute::integral_constant, - cute::integral_constant, - cute::integral_constant, - cute::integral_constant>, - TAs...>, TMs...>) { - - return TiledMMA>, - TAs...>, TMs...>{}; -} - -template -CUTE_HOST_DEVICE constexpr -auto -to_tiled_mma_sm100_ts( - TiledMMA, - TAs...>, TMs...>) { - return TiledMMA, - TAs...>, TMs...>{}; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTLASS_DEVICE void copy(TiledCopy const &tiled_copy, Tensor const &S, - Tensor &D, Tensor const &identity_MN, - Tensor const &predicate_K, const int max_MN=0) { - // Decay TiledCopy to CopyAtom - auto copy_atom = static_cast(tiled_copy); - CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K - // There's no case where !Clear_OOB_K && Clear_OOB_MN - static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); - auto has_with_bool = cute::is_valid([](auto t)->void_t().with(true))>{}, copy_atom); - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - bool predicate_mn = Is_even_MN || get<0>(identity_MN(_0{}, m, _0{})) < max_MN; - if constexpr (Is_even_MN || !Clear_OOB_MN) { - if (Is_even_MN || predicate_mn) { - #pragma unroll - for (int k = 0; k < size<2>(S); ++k) { - if constexpr (Is_even_K || !Clear_OOB_K) { - if (Is_even_K || predicate_K(k)) { cute::copy(copy_atom, S(_, m, k), D(_, m, k)); } - } else { // Clear_OOB_K == true && Is_even_K == false - // If copy traits can be transformed with a predicate value, do it, otherwise branch here - if constexpr (has_with_bool) { - cute::copy(copy_atom.with(predicate_K(k)), S(_, m, k), D(_, m, k)); - } else { - if (predicate_K(k)) { - cute::copy(copy_atom, S(_, m, k), D(_, m, k)); - } else { - cute::clear(D(_, m, k)); - } - } - } - } - } - } else { // Clear_OOB_MN == true && Is_even_MN == false, also implies Clear_OOB_K == true - if constexpr (!has_with_bool) { - if (predicate_mn) { - #pragma unroll - for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || predicate_K(k)) { - cute::copy(copy_atom, S(_, m, k), D(_, m, k)); - } else if (Clear_OOB_K) { - cute::clear(D(_, m, k)); - } - } - } else { - cute::clear(D(_, m, _)); - } - } else { // combine the mn predicate with the k predicate - #pragma unroll - for (int k = 0; k < size<2>(S); ++k) { - cute::copy(copy_atom.with(predicate_mn && (Is_even_K || predicate_K(k))), S(_, m, k), D(_, m, k)); - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Byte permute and shuffle to match register layout of -// (FP8 downcasted) accumulator of GEMM-I to FP8 operand A of GEMM-II. -template -CUTLASS_DEVICE void permute_Aregs_fp8(Fragment &frag) { - // frag has shape ((4, 2, 2), MMA_M, MMA_N), each element is 8 bits - static_assert(decltype(size<0, 0>(frag))::value == 4); - static_assert(decltype(size<0, 1>(frag))::value == 2); - static_assert(decltype(stride<0, 0>(frag))::value == 1); - static_assert(decltype(stride<0, 1>(frag))::value == 4); - static_assert(sizeof(typename Fragment::value_type) == 1); - - int quad_idx = threadIdx.x % 4; - bool lane_03 = quad_idx == 0 || quad_idx == 3; - int selector_upper = lane_03 ? 0x5410 : 0x1054; - int selector_lower = lane_03 ? 0x7632 : 0x3276; - - static constexpr int upper_map[4] = {0, 3, 1, 2}; - // static constexpr int lower_map[4] = {1, 2, 0, 3}; - - Tensor frag_64b = recast(frag); // ((1, 1, 2), MMA_M, MMA_N) - #pragma unroll - for (int i = 0; i < size(frag_64b); ++i) { - uint32_t upper = frag_64b[i].x; - uint32_t lower = frag_64b[i].y; - uint32_t upper0 = lane_03 ? upper : lower; - uint32_t lower0 = lane_03 ? lower : upper; - upper0 = __shfl_sync(uint32_t(-1), upper0, upper_map[quad_idx], 4); - // lower0 = __shfl_sync(uint32_t(-1), lower0, lower_map[quad_idx], 4); - lower0 = __shfl_sync(uint32_t(-1), lower0, upper_map[quad_idx] ^ 1, 4); - frag_64b[i].x = __byte_perm(upper0, lower0, selector_upper); - frag_64b[i].y = __byte_perm(upper0, lower0, selector_lower); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) { - // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits - static_assert(decltype(size<0, 0>(frag))::value == 2); - static_assert(decltype(size<0, 1>(frag))::value == 2); - static_assert(decltype(size<0, 2>(frag))::value % 2 == 0); - static_assert(decltype(stride<0, 0>(frag))::value == 1); - static_assert(sizeof(typename Fragment::value_type) == 4); - Tensor frag_64b = group_modes<1, 3>(recast(frag)); // ((1, 2, N / 8), (MMA_M, MMA_N)) - #pragma unroll - for (int mi = 0; mi < size<1>(frag_64b); ++mi) { - #pragma unroll - for (int i = 0; i < size<0, 2>(frag_64b) / 2; ++i) { - cutlass::swap(frag_64b(make_coord(_0{}, _1{}, 2 * i), mi), frag_64b(make_coord(_0{}, _0{}, 2 * i + 1), mi)); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTLASS_DEVICE void permute_output_fp8(Fragment &out) { - // out has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits - static_assert(decltype(size<0, 0>(out))::value == 2); - static_assert(decltype(size<0, 1>(out))::value == 2); - static_assert(decltype(size<0, 2>(out))::value % 2 == 0); - static_assert(decltype(stride<0, 0>(out))::value == 1); - static_assert(sizeof(typename Fragment::value_type) == 4); - Tensor frag = group_modes<1, 3>(out); // ((2, 2, N / 8), (MMA_M, MMA_N)) - #pragma unroll - for (int mi = 0; mi < size<1>(frag); ++mi) { - #pragma unroll - for (int j = 0; j < size<0, 1>(frag); ++j) { - #pragma unroll - for (int i = 0; i < size<0, 2>(frag) / 2; ++i) { - cutlass::swap(frag(make_coord(_1{}, j, 2 * i), mi), frag(make_coord(_0{}, j, 2 * i + 1), mi)); - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTLASS_DEVICE void permute_output_fp8_Vcolmajor(Fragment &frag) { - // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 16 bits - static_assert(decltype(size<0, 0>(frag))::value == 2); - static_assert(decltype(size<0, 1>(frag))::value == 2); - static_assert(decltype(stride<0, 0>(frag))::value == 1); - static_assert(sizeof(typename Fragment::value_type) == 2 || sizeof(typename Fragment::value_type) == 4); - - int quad_idx = threadIdx.x % 4; - bool lane_03 = quad_idx == 0 || quad_idx == 3; - - static constexpr int upper_map[4] = {0, 2, 3, 1}; - // static constexpr int lower_map[4] = {2, 0, 1, 3}; - - // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(frag); } - using type2 = std::conditional_t; - Tensor frag_2 = group_modes<1, 3>(recast(frag)); // ((1, 2, N / 8), (MMA_M, MMA_N)) - // if (blockIdx.x == 0 && threadIdx.x == 128) { print(frag); printf("\n"); print(frag_2); } - #pragma unroll - for (int mi = 0; mi < size<1>(frag_2); ++mi) { - #pragma unroll - for (int j = 0; j < size<0, 1>(frag_2); ++j) { - #pragma unroll - for (int i = 0; i < size<0, 2>(frag_2) / 2; ++i) { - type2 upper = frag_2(make_coord(_0{}, j, 2 * i), mi); - type2 lower = frag_2(make_coord(_0{}, j, 2 * i + 1), mi); - type2 upper0 = lane_03 ? upper : lower; - type2 lower0 = lane_03 ? lower : upper; - upper0 = __shfl_sync(uint32_t(-1), upper0, upper_map[quad_idx], 4); - // lower0 = __shfl_sync(uint32_t(-1), lower0, lower_map[quad_idx], 4); - lower0 = __shfl_sync(uint32_t(-1), lower0, upper_map[quad_idx] ^ 2, 4); - frag_2(make_coord(_0{}, j, 2 * i), mi) = lane_03 ? upper0 : lower0; - frag_2(make_coord(_0{}, j, 2 * i + 1), mi) = lane_03 ? lower0 : upper0; - } - } - } - // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(frag); } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTLASS_DEVICE void apply_softcap(Tensor &tensor, float const softcap){ - #pragma unroll - for (int i = 0; i < size(tensor); ++i) { - tensor(i) = cutlass::fast_tanh(tensor(i) * softcap); - } -} - -template -CUTLASS_DEVICE auto calculate_dtanh(Tensor &tensor){ - Tensor out = make_fragment_like(tensor); - #pragma unroll - for (int i = 0; i < size(tensor); ++i) { - out(i) = 1.f - (tensor(i) * tensor(i)); - } - return out; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTE_DEVICE T warp_prefix_sum(T val) { - int lane = threadIdx.x % cutlass::NumThreadsPerWarp; - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < cutlass::NumThreadsPerWarp; i <<= 1) { - T partial_sum = __shfl_up_sync(0xffffffff, val, i); - if (lane >= i) { val += partial_sum; } - } - return val; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTE_DEVICE T warp_shfl_get(T val, int src_lane) { - return __shfl_sync(0xffffffff, val, src_lane); -}; - -template -CUTE_DEVICE T warp_shfl_get_last(T val) { - return __shfl_sync(0xffffffff, val, cutlass::NumThreadsPerWarp - 1); -}; - -CUTE_DEVICE int warp_last_true_laneid(bool cond) { - return __popc(__ballot_sync(0xffffffff, cond)); -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTE_DEVICE T warp_uniform(T a) { - return __shfl_sync(0xffffffff, a, 0); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -CUTLASS_DEVICE -int canonical_warp_group_idx_nosync() { - return threadIdx.x / cutlass::NumThreadsPerWarpGroup; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace flash diff --git a/readme_example.py b/readme_example.py deleted file mode 100644 index 760be77de7ca1ef47d0e49a1cfdddd479cd39e9c..0000000000000000000000000000000000000000 --- a/readme_example.py +++ /dev/null @@ -1,47 +0,0 @@ -# /// script -# requires-python = ">=3.10" -# dependencies = [ -# "torch", -# "triton", -# "numpy", -# "kernels", -# ] -# /// - -import torch -from kernels import get_kernel - -# Load vllm-flash-attn3 via kernels library -vllm_flash_attn3 = get_kernel("kernels-community/vllm-flash-attn3") - -# Access Flash Attention function -flash_attn_func = vllm_flash_attn3.flash_attn_func - -# Set device and seed for reproducibility -device = "cuda" -torch.manual_seed(42) -torch.cuda.manual_seed(42) - -# Parameters -batch_size = 2 -seqlen_q = 128 # Query sequence length -seqlen_k = 256 # Key sequence length -nheads = 8 # Number of attention heads -d = 64 # Head dimension - -# Create input tensors (Q, K, V) -q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=torch.bfloat16) -k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=torch.bfloat16) -v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=torch.bfloat16) - -print(f"Query shape: {q.shape}") -print(f"Key shape: {k.shape}") -print(f"Value shape: {v.shape}") - -# Run Flash Attention 3 -output, lse = flash_attn_func(q, k, v, causal=True) - -print(f"\nOutput shape: {output.shape}") -print(f"LSE (log-sum-exp) shape: {lse.shape}") -print(f"\nAttention computation successful!") -print(f"Output tensor stats - Mean: {output.mean().item():.4f}, Std: {output.std().item():.4f}") \ No newline at end of file diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py deleted file mode 100644 index 3b4497732f8286cd0af882a633b2585fbfcdac05..0000000000000000000000000000000000000000 --- a/tests/test_flash_attn.py +++ /dev/null @@ -1,1110 +0,0 @@ -import os -import math -import itertools - -import pytest -import torch -import torch.nn.functional as F - -from einops import rearrange, repeat -try: - from flash_attn.layers.rotary import apply_rotary_emb -except ImportError: - apply_rotary_emb = None - -from padding import pad_input, unpad_input -from test_util import ( - attention_ref, - generate_qkv, - generate_random_padding_mask, -) - -from flash_attn import flash_attn_func, flash_attn_varlen_func, flash_attn_combine -from flash_attn import flash_attn_with_kvcache, get_scheduler_metadata - - -DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" -DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" -DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE" -DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE" -DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "FALSE") == "TRUE" -DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE" -DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE" -DISABLE_FP16 = os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE" -DISABLE_FP8 = os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE" or torch.cuda.get_device_capability("cuda")[0] < 9 -DISABLE_HDIM64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM64", "FALSE") == "TRUE" -DISABLE_HDIM96 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM96", "FALSE") == "TRUE" -DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE" -DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" -DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" - -COMPILED_HDIMS = ( - [] - + ([64] if not DISABLE_HDIM64 else []) - + ([96] if not DISABLE_HDIM96 else []) - + ([128] if not DISABLE_HDIM128 else []) - + ([192] if not DISABLE_HDIM192 else []) - + ([256] if not DISABLE_HDIM256 else []) -) - - -# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) -@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) -# @pytest.mark.parametrize("dtype", [torch.bfloat16]) -# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) -@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -# @pytest.mark.parametrize("mha_type", ["mha"]) -# @pytest.mark.parametrize("has_qv", [False, True]) -@pytest.mark.parametrize("has_qv", [False]) -# @pytest.mark.parametrize("deterministic", [False, True]) -@pytest.mark.parametrize("deterministic", [False]) -@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) -# @pytest.mark.parametrize("softcap", [0.0]) -@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) -# @pytest.mark.parametrize("local", [False]) -@pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [False]) -# @pytest.mark.parametrize("V_colmajor", [False, True]) -@pytest.mark.parametrize("V_colmajor", [False]) -# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) -# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) -# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) -# @pytest.mark.parametrize('d', [56, 80]) -# @pytest.mark.parametrize("d", [64, 128, 256]) -# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) -# @pytest.mark.parametrize("d", [64, 96, 128, 192]) -@pytest.mark.parametrize("d", COMPILED_HDIMS) -# @pytest.mark.parametrize("d", [128]) -@pytest.mark.parametrize( - "seqlen_q,seqlen_k", - [ - (1, 1), - (64, 128), - (128, 192), - (256, 256), - (239, 1), - (799, 3), - (113, 203), - (113, 128), - (128, 217), - (113, 211), - (108, 256), - (256, 512), - (384, 256), - (640, 128), - (512, 256), - (1024, 1024), - (1023, 1024), - (1024, 1023), - (4096, 4096), - (4224, 4224), - ], -) -# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) -def test_flash_attn_output( - seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype -): - if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn): - pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn") - device = "cuda" - # set seed - torch.random.manual_seed(0) - # batch_size = 40 - # nheads = 16 - batch_size = 9 if seqlen_k <= 2048 else 2 - # batch_size = 1 - nheads = 6 - # nheads = 1 - nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) - dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) - if dtype == torch.float8_e4m3fn: - dv_vals = [d] - for dv in dv_vals: - q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) - if softcap > 0.0: - # Ensure the values of qk are at least within softcap range. - q_ref = (q_ref * softcap / 4) - q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() - k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - if has_qv: - qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - else: - qv_ref = None - # Put window_size after QKV randn so that window_size changes from test to test - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - # window_size = (-1, -1) if not local else (16, 0) - if dtype == torch.float8_e4m3fn: - q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] - else: - q_descale, k_descale, v_descale = None, None, None - q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] - qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None - if V_colmajor: - v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_() - out_ref, attn_ref = attention_ref( - q_ref, - k_ref, - v_ref, - None, - None, - causal=causal, - qv=qv_ref, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - softcap=softcap - ) - out_pt, attn_pt = attention_ref( - q_ref, - k_ref, - v_ref, - None, - None, - causal=causal, - qv=qv_ref, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - softcap=softcap, - upcast=False, - reorder_ops=True, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, - ) - - # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() - # if qv is not None: - # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() - # m = qk.amax(-1, keepdim=True) - # s_tmp = torch.exp((qk - m) / math.sqrt(d)) - # exp_sum = s_tmp.sum(-1) - # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) - # lse_ref = torch.logsumexp(qk, dim=-1) - - # Numerical error if we just do any arithmetic on out_ref - fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() - rtol = 2 if softcap == 0.0 else 3 - - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] - num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] - for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): - out, lse = flash_attn_func( - q, - k, - v, - causal=causal, - qv=qv, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - softcap=softcap, - pack_gqa=pack_gqa, - num_splits=num_splits - ) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - # if not causal: - # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") - # breakpoint() - - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. - assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol - - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor and not has_qv: - g = torch.randn_like(out) - do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) - # import flash_attn_3_cuda - # dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum = flash_attn_3_cuda.bwd( - # g, - # q, - # k, - # v, - # out, - # lse, - # None, - # None, - # None, - # d ** (-0.5), - # causal, - # window_size[0], window_size[1], - # softcap, - # deterministic, - # 0, # sm_margin - # ) - dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) - # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") - # assert (softmax_d - do_o).abs().max().item() <= 1e-5 - # assert dq_accum.abs().max().item() == 0.0 - - # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) - # P = torch.softmax(qk, -1) - # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1)) - # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) - # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) - # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) - - # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) - dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) - dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) - print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") - print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") - print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") - print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") - print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") - print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") - print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") - print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") - print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") - print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") - print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") - print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") - # breakpoint() - - - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor and not has_qv: - dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol - dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol - dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol - - -# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) -@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) -# @pytest.mark.parametrize("dtype", [torch.bfloat16]) -# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) -@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -# @pytest.mark.parametrize("mha_type", ["mha"]) -# @pytest.mark.parametrize("has_qv", [False, True]) -@pytest.mark.parametrize("has_qv", [False]) -# @pytest.mark.parametrize("deterministic", [False, True]) -@pytest.mark.parametrize("deterministic", [False]) -@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) -# @pytest.mark.parametrize("softcap", [0.0]) -@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) -# @pytest.mark.parametrize("local", [False]) -@pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [False]) -@pytest.mark.parametrize("add_unused_qkv", [False, True]) -# @pytest.mark.parametrize("add_unused_qkv", [True]) -# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) -# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) -# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) -# @pytest.mark.parametrize('d', [56, 80]) -# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) -# @pytest.mark.parametrize("d", [64, 96, 128]) -@pytest.mark.parametrize("d", COMPILED_HDIMS) -# @pytest.mark.parametrize("d", [128]) -@pytest.mark.parametrize( - "seqlen_q,seqlen_k", - [ - (1, 1), - (1, 3), - (2, 1), - (511, 1), - (3, 513), - (64, 128), - (128, 128), - (256, 256), - (113, 203), - (128, 217), - (113, 211), - (108, 256), - (256, 512), - (307, 256), - (640, 128), - (512, 256), - (1024, 1024), - (1023, 1024), - (1024, 1023), - (2048, 2048), - ], -) -def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype -): - device = "cuda" - # set seed - torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) - # batch_size = 40 - # nheads = 16 - batch_size = 9 if seqlen_q <= 2048 else 2 - nheads = 6 - # batch_size = 2 - # nheads = 1 - nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) - dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) - if dtype == torch.float8_e4m3fn: - dv_vals = [d] - for dv in dv_vals: - q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) - if softcap > 0.0: - # Ensure the values of qk are at least within softcap range. - q_ref = (q_ref * softcap / 4).detach().requires_grad_() - q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() - k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - if has_qv: - qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - else: - qv_ref = None - # Put window_size after QKV randn so that window_size changes from test to test - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - if dtype == torch.float8_e4m3fn: - q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] - else: - q_descale, k_descale, v_descale = None, None, None - q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] - qv = qv_ref.detach() if has_qv else None - query_padding_mask = generate_random_padding_mask( - seqlen_q, batch_size, device, mode="random", zero_lengths=False - ) - key_padding_mask = generate_random_padding_mask( - seqlen_k, batch_size, device, mode="random", zero_lengths=True - ) - - def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): - if add_unused: - another_mask = generate_random_padding_mask(max_seq_len, bs, device) - attn_mask = torch.logical_and(padding_mask, another_mask) - unused_mask = torch.logical_xor( - torch.logical_or(padding_mask, another_mask), attn_mask - ) - else: - attn_mask = padding_mask - unused_mask = None - return attn_mask, unused_mask - - query_padding_mask, query_unused_mask = _gen_unused_masks( - query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device - ) - key_padding_mask, key_unused_mask = _gen_unused_masks( - key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device - ) - - ( - q_unpad, - k_unpad, - v_unpad, - qv_unpad, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, - max_seqlen_q, - max_seqlen_k, - q, - k, - v, - qv, - output_pad_fn, - dq_pad_fn, - dk_pad_fn, - ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False, - query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) - q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] - out_ref, attn_ref = attention_ref( - q_ref, - k_ref, - v_ref, - query_padding_mask, - key_padding_mask, - causal=causal, - qv=qv_ref, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - softcap=softcap - ) - out_pt, attn_pt = attention_ref( - q_ref, - k_ref, - v_ref, - query_padding_mask, - key_padding_mask, - causal=causal, - qv=qv_ref, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - softcap=softcap, - upcast=False, - reorder_ops=True, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, - ) - - - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - - if query_unused_mask is not None: - q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") - - # Numerical error if we just do any arithmetic on out_ref - fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() - rtol = 2 if softcap == 0.0 else 3 - - pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] - num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] - for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): - out_unpad, lse = flash_attn_varlen_func( - q_unpad, - k_unpad, - v_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - seqused_q=seqused_q, - seqused_k=seqused_k, - causal=causal, - qv=qv_unpad, - q_descale=q_descale, - k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - softcap=softcap, - ) - out = output_pad_fn(out_unpad) - if query_unused_mask is not None: - out.masked_fill_(q_zero_masking, 0.0) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - # if not causal: - # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") - # breakpoint() - - # Check that FlashAttention's numerical error is at most 3x the numerical error - # of a Pytorch implementation. - assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol - - - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv: - g_unpad = torch.randn_like(out_unpad) - do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) - # import flash_attn_3_cuda - # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen( - # g_unpad, - # q_unpad, - # k_unpad, - # v_unpad, - # out_unpad, - # lse, - # None, - # None, - # None, - # cu_seqlens_q, - # cu_seqlens_k, - # None, None, - # max_seqlen_q, - # max_seqlen_k, - # d ** (-0.5), - # causal, - # window_size[0], window_size[1], - # softcap, - # deterministic, - # 0, # sm_margin - # ) - dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad) - dq = dq_pad_fn(dq_unpad) - dk = dk_pad_fn(dk_unpad) - dv = dk_pad_fn(dv_unpad) - if key_unused_mask is not None: - k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") - dk.masked_fill_(k_zero_masking, 0.0) - dv.masked_fill_(k_zero_masking, 0.0) - if query_unused_mask is not None: - dq.masked_fill_(q_zero_masking, 0.0) - # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") - # assert (softmax_d - do_o).abs().max().item() <= 1e-5 - # assert dq_accum.abs().max().item() == 0.0 - g = output_pad_fn(g_unpad) - - # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float() - # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) - # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) - # P = torch.softmax(qk, -1) - # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1)) - # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) - # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) - # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) - - - # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) - dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) - dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) - print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") - print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") - print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") - print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") - print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") - print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") - print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") - print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") - print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") - print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") - print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") - print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") - # breakpoint() - - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv: - dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol - dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol - dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol - - -# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) -@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) -# @pytest.mark.parametrize("dtype", [torch.bfloat16]) -# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) -@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -# @pytest.mark.parametrize("mha_type", ["mha"]) -@pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) -# @pytest.mark.parametrize("new_kv", [True]) -@pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else [])) -# @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) -# @pytest.mark.parametrize("causal,local", [(False, False)]) -@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]) -# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) -@pytest.mark.parametrize("has_rotary_seqlens", [False, True]) -# @pytest.mark.parametrize("has_rotary_seqlens", [False]) -@pytest.mark.parametrize("rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False]) -# @pytest.mark.parametrize("rotary_interleaved", [True]) -@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) else [0.0]) -# @pytest.mark.parametrize("rotary_fraction", [0.0]) -@pytest.mark.parametrize("page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else [])) -# @pytest.mark.parametrize("page_size", [None]) -@pytest.mark.parametrize("has_leftpad", [False, True]) -# @pytest.mark.parametrize("has_leftpad", [False]) -@pytest.mark.parametrize("has_batch_idx", [False, True]) -# @pytest.mark.parametrize("has_batch_idx", [False]) -@pytest.mark.parametrize("varlen_q", [False, True]) -# @pytest.mark.parametrize("varlen_q", [False]) -# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) -# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) -# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) -# @pytest.mark.parametrize('d', [56, 80]) -@pytest.mark.parametrize("d", [64]) -# @pytest.mark.parametrize("d", [192]) -@pytest.mark.parametrize( - "seqlen_q,seqlen_k", - [ - (1, 128), - (1, 339), - (3, 1024), - (64, 800), - (64, 256), - (3, 799), - (64, 2048), - (16, 20000), - # (1, 128 * 1024), - # (16, 128 * 1024), - (128, 128), - (256, 512), # To test appending KV with more than 1 block - (2048, 3577), # Enough tile to test persistent scheduler - ], -) -# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) -def test_flash_attn_kvcache( - seqlen_q, - seqlen_k, - d, - varlen_q, - has_batch_idx, - has_leftpad, - page_size, - rotary_fraction, - rotary_interleaved, - has_rotary_seqlens, - seqlen_new_eq_seqlen_q, - causal, - local, - new_kv, - mha_type, - dtype, -): - if page_size is not None and seqlen_k % page_size != 0: - pytest.skip() - if seqlen_q > seqlen_k and new_kv: - pytest.skip() - if not new_kv and rotary_fraction > 0.0: - pytest.skip() - if rotary_fraction == 0.0 and has_rotary_seqlens: - pytest.skip() - device = "cuda" - # set seed - torch.random.manual_seed(0) - batch_size = 5 - # batch_size = 1 - batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 - nheads = 6 - # nheads = 1 - # rotary_dim must be a multiple of 16, and must be <= d - rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 - nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) - assert nheads % nheads_k == 0 - dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) - if dtype == torch.float8_e4m3fn: - dv_vals = [d] - for dv in dv_vals: - has_qv = d == 64 and dv >= 256 - q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - if has_qv: - qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - else: - qv = None - if varlen_q: - query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask) - output_pad_fn = lambda output_unpad: pad_input( - output_unpad, indices_q, batch_size, seqlen_q - ) - qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None - else: - query_padding_mask = None - q_unpad = q - qv_unpad = qv - cu_seqlens_q, max_seqlen_q = None, None - # Put window_size after QKV randn so that window_size changes from test to test - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - - seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() - cu_seqlens_k_new = None - key_new_padding_mask = None - if new_kv: - k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - v = torch.randn(batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - if varlen_q: # k & v are also varlen - key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random") - k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask) - v_unpad, *rest = unpad_input(v, key_new_padding_mask) - else: - k_unpad, v_unpad = k, v - else: - k, v, k_unpad, v_unpad = None, None, None, None - if page_size is None: - k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - page_table = None - else: - ( - k_cache, - v_cache, - page_table, - k_cache_paged, - v_cache_paged, - num_blocks, - ) = _generate_block_kvcache( - seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype, dtype_ref - ) - cache_seqlens = torch.randint( - 0 if new_kv else 1, - # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough - ( - (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) - if new_kv - else (seqlen_k + 1) - ), - (batch_size,), - dtype=torch.int32, - device=device, - ) - if has_leftpad: - cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) - if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) - for i in range(batch_size)]) - else: - cache_leftpad = None - if has_batch_idx: - cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ - :batch_size - ] - else: - cache_batch_idx = None - arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") - cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") - if not new_kv: - key_padding_mask = arange < cache_seqlens_expanded - else: - k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new - key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens - if has_leftpad: - key_padding_mask = torch.logical_and( - key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) - ) - # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) - rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2 - if rotary_dim > 0: - angle = ( - torch.rand( - seqlen_k if page_size is None else num_blocks * page_size, - rotary_dim // 2, - device=device, - ) - * 2 - * math.pi - ) - cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) - sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) - if causal or local: - q_ro = apply_rotary_emb( - q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved - ) - else: - q_ro = rearrange( - apply_rotary_emb( - rearrange(q, "b s h d -> b 1 (s h) d"), - cos, - sin, - seqlen_offsets=rotary_seqlens, - interleaved=rotary_interleaved, - ), - "b 1 (s h) d -> b s h d", - s=seqlen_q, - ) - # q_ro = q - k_ro = apply_rotary_emb( - k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved - ) - else: - cos, sin = None, None - q_ro, k_ro = q, k - # k_cache[:, 64:] = -1 - k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone() - v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone() - if new_kv: - update_mask = torch.logical_and( - cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens - ) - k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") - v_to_update = rearrange(v, "b s ... -> (b s) ...") - if varlen_q: - k_to_update = k_to_update[indices_k] - v_to_update = v_to_update[indices_k] - k_cache_ref[update_mask] = k_to_update - v_cache_ref[update_mask] = v_to_update - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) - out_ref, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - query_padding_mask, - key_padding_mask, - causal=causal, - qv=qv, - window_size=window_size, - key_leftpad=cache_leftpad, - ) - out_pt, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - query_padding_mask, - key_padding_mask, - causal=causal, - qv=qv, - window_size=window_size, - upcast=False, - reorder_ops=True, - key_leftpad=cache_leftpad, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None - ) - q = q.to(dtype) - q_unpad = q_unpad.to(dtype) if varlen_q else None - k_cache = k_cache.to(dtype) - v_cache = v_cache.to(dtype) - k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None - v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None - k = k.to(dtype) if k is not None else None - v = v.to(dtype) if v is not None else None - k_unpad = k_unpad.to(dtype) if k_unpad is not None else None - v_unpad = v_unpad.to(dtype) if v_unpad is not None else None - qv = qv.to(dtype) if qv is not None else None - qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None - cos = cos.to(dtype) if cos is not None else None - sin = sin.to(dtype) if sin is not None else None - k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() - v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() - num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1] - precompute_metadata_vals = [False, True] - for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals): - if precompute_metadata: - scheduler_metadata = get_scheduler_metadata( - batch_size, seqlen_q, seqlen_k, nheads, nheads_k, d, - cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, - max_seqlen_k_new=seqlen_new, page_size=page_size, - causal=causal, window_size=window_size, - num_splits=num_splits - ) - else: - scheduler_metadata = None - # Repeat to test metadata reuse - for _ in range(1 if not precompute_metadata else 2): - if page_size is None: - k_cache.copy_(k_cache_saved) - v_cache.copy_(v_cache_saved) - else: - k_cache_paged.copy_(k_cache_saved) - v_cache_paged.copy_(v_cache_saved) - out, lse, *rest = flash_attn_with_kvcache( - q if not varlen_q else q_unpad, - k_cache if page_size is None else k_cache_paged, - v_cache if page_size is None else v_cache_paged, - k if not new_kv or not varlen_q else k_unpad, - v if not new_kv or not varlen_q else v_unpad, - qv=qv if not varlen_q else qv_unpad, - rotary_cos=cos, - rotary_sin=sin, - cache_seqlens=cache_seqlens, - cache_batch_idx=cache_batch_idx, - cache_leftpad=cache_leftpad, - page_table=page_table, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k_new=cu_seqlens_k_new, - max_seqlen_q=max_seqlen_q, - rotary_seqlens=rotary_seqlens, - causal=causal, - window_size=window_size, - rotary_interleaved=rotary_interleaved, - scheduler_metadata=scheduler_metadata, - num_splits=num_splits, - return_softmax_lse=True - ) - if varlen_q: - out = output_pad_fn(out) - # out = flash_attn_with_kvcache( - # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size - # ) - # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) - # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) - # m = qk.amax(-1, keepdim=True) - # s_tmp = torch.exp((qk - m) / math.sqrt(d)) - # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) - # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) - # probs = torch.softmax(qk, dim=-1) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - # breakpoint() - - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. - if new_kv: - if page_size is None: - k_cache_select = ( - k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] - ) - v_cache_select = ( - v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] - ) - else: - k_cache_select = rearrange( - k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k].to(dtype_ref) - v_cache_select = rearrange( - v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k].to(dtype_ref) - k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) - v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) - if dtype is not torch.float8_e4m3fn: - assert torch.equal(v_cache_select, v_cache_ref) - else: - assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) - # breakpoint() - # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: - if rotary_dim == 0: - assert torch.equal(k_cache_select, k_cache_ref) - else: - # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): - # breakpoint() - if dtype is not torch.float8_e4m3fn: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) - else: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) - mult = 4 if dtype == torch.float8_e4m3fn else 2 - assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 - mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 - assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() - - -def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref): - num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 - k_cache_paged = torch.randn( - num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref - ).to(dtype).to(dtype_ref) - v_cache_paged = torch.randn( - num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref - ).to(dtype).to(dtype_ref) - page_table = rearrange( - torch.randperm(num_blocks, dtype=torch.int32, device=device), - "(b nblocks) -> b nblocks", - b=batch_size, - ) - k_cache = rearrange( - k_cache_paged[page_table.flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k] - v_cache = rearrange( - v_cache_paged[page_table.flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k] - return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks - - -@pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize('causal', [False]) -@pytest.mark.parametrize('d', [128]) -@pytest.mark.parametrize( - "seqlen_q,seqlen_k", - [ - (64, 8192), - ], -) -def test_flash_attn_cluster(seqlen_q, seqlen_k, d, causal, dtype): - device = "cuda" - torch.random.manual_seed(0) - batch_size = 2 - nheads = 16 - nheads_kv = 4 - # There was a bug where this would cause "unspecified launch failure" due to Cluster - q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) - k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype) - v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype) - for _ in range(100): - flash_attn_func(q, k, v, causal=causal) - - -# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize('causal', [False]) -@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) -# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128]) -# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128]) -# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192]) -# @pytest.mark.parametrize('d', [80]) -@pytest.mark.parametrize( - "seqlen_q,seqlen_k", - [ - (1, 239), - (239, 1), - (3, 799), - (799, 3), - (1024, 128), - (97, 97), - (128, 128), - (200, 200), - (256, 256), - (257, 257), - (384, 384), - (512, 512), - (768, 768), - (1024, 1024), - (2048, 2048), - ], -) -def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype): - device = "cuda" - # set seed - torch.random.manual_seed(0) - # Simulate under memory load - dummy = torch.empty(70 * 1024 ** 3, dtype=torch.uint8, device=device) - batch_size = 60 # Sometimes we need large batch size for the race conditions to trigger - nheads = 4 - q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) - k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) - v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) - torch.random.manual_seed(42) - out0, lse0 = flash_attn_func(q, k, v, causal=causal) - g = torch.randn_like(out0) - dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g) - # Numerical error if we just do any arithmetic on dq - dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item() - - for i in range(1000): - torch.random.manual_seed(42) - out, lse = flash_attn_func(q, k, v, causal=causal) - assert torch.equal(out, out0) - assert torch.equal(lse, lse0) - - dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) - dq_equal = torch.allclose(dq, dq0, atol=dq_atol) - if not dq_equal: - print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}") - # breakpoint() - assert torch.equal(dv, dv0) - assert torch.equal(dk, dk0) - assert dq_equal - - -def attention_combine_ref(out_partial, lse_partial): - """ - out_partial: (num_splits, batch_size, seqlen, nheads, d) - lse_partial: (num_splits, batch_size, nheads, seqlen) - """ - lse = torch.logsumexp(lse_partial, dim=0) - scale = torch.exp(lse_partial - lse) - scale = torch.where(torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale) - out = (scale.unsqueeze(-1) * out_partial).sum(0) - return out, lse - - -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) -# @pytest.mark.parametrize("dtype", [torch.float32]) -# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) -@pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512]) -# @pytest.mark.parametrize("d", [128]) -@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024]) -# @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192]) -# @pytest.mark.parametrize("seqlen", [15]) -@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133]) -# @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11]) -# @pytest.mark.parametrize("num_splits", [128]) -def test_flash_attn_combine(num_splits, seqlen, d, dtype): - if DISABLE_SPLIT: - pytest.skip() - device = "cuda" - # set seed - torch.random.manual_seed(1) - batch_size = 5 - nheads = 16 - # batch_size = 1 - # nheads = 1 - out_partial = torch.randn(num_splits * 2, batch_size, nheads, seqlen, d, device=device, dtype=torch.float32).transpose(2, 3)[:num_splits] # To test non-contiguous tensor - lse_partial = torch.randn(num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32).transpose(-1, -2)[:, :, :, :nheads] # To test non-contiguous tensor - # To test short-circuiting based on num_splits - lse_partial[num_splits // 2:, :batch_size // 3] = -float("inf") - out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype) - out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) - out_pt = out_ref.to(dtype) - - print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") - print(f"LSE mean diff: {(lse - lse_ref).abs().mean().item()}") - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - # breakpoint() - - assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5) - multiple = 2 - assert ((out - out_ref).abs().max().item() <= multiple * (out_pt - out_ref).abs().max().item()) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) - - # from flash_attn.utils.benchmark import pytorch_profiler - # # pytorch_profiler(torch.sum, lse_partial) - # pytorch_profiler(flash_attn_combine, out_partial, lse_partial) - # pytorch_profiler(torch.sum, out_partial) diff --git a/torch-ext/pytorch_shim.h b/torch-ext/pytorch_shim.h deleted file mode 100644 index f27e150c827e457a93a45a4d2784a23871a73b6b..0000000000000000000000000000000000000000 --- a/torch-ext/pytorch_shim.h +++ /dev/null @@ -1,105 +0,0 @@ -#pragma once - -#include - -/** - * Unforunately, the type signatures of the flash_attn ops are not compatible - * with the PyTorch library bindings. To get around that we use - * `make_pytorch_shim` which creates a lambda that exponses the API using - * PyTorch compatible types to the types, then converts them to the types - * expected by the flash_attn ops. This shims allows us to make minimal changes - * to `flash_api.cpp` making it easier to synchronize with upstream changes. - * - * The `pytorch_library_compatible_type` struct is used to map from the - * flash_attn ops types to a PyTorch library compatible one. The main issues is - * that the following types are not support by PyTorch libary bindings: - * - `int` - * - `float` - * - `std::optional &` - * - `std::optional &` - * So we convert them to (respectively): - * - `int64_t` - * - `double` - * - `const std::optional&` - * - `const std::optional&` - */ - -template -struct pytorch_library_compatible_type { - using type = T; - static T convert_from_type(T arg) { return arg; } -}; - -template -using pytorch_library_compatible_type_t = \ - typename pytorch_library_compatible_type::type; - -template -T convert_from_pytorch_compatible_type(pytorch_library_compatible_type_t arg) - { return pytorch_library_compatible_type::convert_from_type(arg); } - -// Map `std::optional &` -> `const std::optional&` -// (NOTE: this is bit unsafe but non of the ops in flash_attn mutate -// the optional container) -template -struct pytorch_library_compatible_type &> { - using type = const std::optional&; - static std::optional& convert_from_type(const std::optional &arg) { - return const_cast&>(arg); - } -}; - -// Map `std::optional` -> -// `std::optional>` -// (NOTE: tested for `std::optional` -> `std::optional`) -template -struct pytorch_library_compatible_type> { - using type = std::optional>; - static std::optional> convert_from_type(std::optional arg) { - return arg; - } -}; - -// Map `std::optional&` -> `const std::optional&` -template<> -struct pytorch_library_compatible_type &> { - using type = const std::optional&; - static std::optional& convert_from_type( - const std::optional &arg) { - return const_cast&>( - reinterpret_cast&>(arg)); - } -}; - -// Map `int` -> `int64_t` -template<> struct pytorch_library_compatible_type { - using type = int64_t; - static int convert_from_type(int64_t arg) { - TORCH_CHECK(arg <= std::numeric_limits::max(), - "int64_t value is too large to be converted to int"); - TORCH_CHECK(arg >= std::numeric_limits::min(), - "int64_t value is too small to be converted to int"); - return arg; - } -}; - -// Map `float` -> `double` -template<> struct pytorch_library_compatible_type { - using type = double; - static float convert_from_type(double arg) { - TORCH_CHECK(std::abs(arg) <= std::numeric_limits::max(), - "double value is too large to be converted to float"); - return arg; - } -}; - -// -// Shim Utils -// - -template -auto make_pytorch_shim(Ret(*fun)(Args... args)){ - return [fun](pytorch_library_compatible_type_t... args) { - return fun(convert_from_pytorch_compatible_type(args)...); - }; -} diff --git a/torch-ext/torch_binding.cpp b/torch-ext/torch_binding.cpp deleted file mode 100644 index 6c1cb71993db11c1cf5a0128c321fd20a281fe5d..0000000000000000000000000000000000000000 --- a/torch-ext/torch_binding.cpp +++ /dev/null @@ -1,101 +0,0 @@ -#include - -#include "pytorch_shim.h" -#include "registration.h" -#include "torch_binding.h" - -TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { - ops.def("fwd(Tensor! q," - " Tensor k," - " Tensor v," - " Tensor? k_new," - " Tensor? v_new," - " Tensor? q_v," - " Tensor!? out," - " Tensor? cu_seqlens_q," - " Tensor? cu_seqlens_k," - " Tensor? cu_seqlens_k_new," - " Tensor? seqused_q," - " Tensor? seqused_k," - " int? max_seqlen_q," - " int? max_seqlen_k," - " Tensor? page_table," - " Tensor? kv_batch_idx," - " Tensor? leftpad_k," - " Tensor? rotary_cos," - " Tensor? rotary_sin," - " Tensor? seqlens_rotary," - " Tensor? q_descale," - " Tensor? k_descale," - " Tensor? v_descale," - " float softmax_scale," - " bool is_causal," - " int window_size_left," - " int window_size_right," - " float softcap," - " bool is_rotary_interleaved," - " Tensor? scheduler_metadata," - " int num_splits," - " bool? pack_gqa," - " int sm_margin," - " Tensor? s_aux_) -> Tensor[]"); - ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd)); - - ops.def("bwd(Tensor dout," - " Tensor q," - " Tensor k," - " Tensor v," - " Tensor out," - " Tensor softmax_lse," - " Tensor!? dq," - " Tensor!? dk," - " Tensor!? dv," - " Tensor? cu_seqlens_q," - " Tensor? cu_seqlens_k," - " Tensor? seqused_q," - " Tensor? seqused_k," - " int? max_seqlen_q," - " int? max_seqlen_k," - " float softmax_scale," - " bool is_causal," - " int window_size_left," - " int window_size_right," - " float softcap," - " bool deterministic," - " int sm_margin) -> Tensor[]"); - ops.impl("bwd", torch::kCUDA, make_pytorch_shim(&mha_bwd)); - - ops.def("fwd_combine(Tensor out_partial," - " Tensor lse_partial," - " Tensor!? out," - " ScalarType? out_dtype) -> Tensor[]"); - ops.impl("fwd_combine", torch::kCUDA, make_pytorch_shim(&mha_combine)); - - ops.def("get_scheduler_metadata(" - " int batch_size," - " int max_seqlen_q," - " int max_seqlen_k," - " int num_heads," - " int num_heads_k," - " int headdim," - " int headdim_v," - " ScalarType qkv_dtype," - " Tensor seqused_k," - " Tensor? cu_seqlens_q," - " Tensor? cu_seqlens_k," - " Tensor? cu_seqlens_k_new," - " Tensor? seqused_q," - " Tensor? leftpad_k," - " int? page_size," - " int max_seqlen_k_new," // 0 means we're not appending new KV - " bool is_causal," - " int window_size_left," - " int window_size_right," - " bool has_softcap," - " int num_splits," - " bool? pack_gqa," - " int sm_margin) -> Tensor"); - ops.impl("get_scheduler_metadata", torch::kCUDA, make_pytorch_shim(&mha_fwd_get_scheduler_metadata)); -} - -REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/torch-ext/torch_binding.h b/torch-ext/torch_binding.h deleted file mode 100644 index 5d003aa0afdb3086b020c99bbd7c546c80b7a75b..0000000000000000000000000000000000000000 --- a/torch-ext/torch_binding.h +++ /dev/null @@ -1,103 +0,0 @@ -#pragma once - -#include -#include - -#include - -std::vector -mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor &k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. - const at::Tensor &v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. - std::optional &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new - std::optional &v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new - std::optional &q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q - std::optional &out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q - std::optional &cu_seqlens_q_, // b+1 - std::optional &cu_seqlens_k_, // b+1 - std::optional &cu_seqlens_k_new_, // b+1 - std::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. - std::optional &seqused_k_, // b. If given, only this many elements of each batch element's keys are used. - std::optional max_seqlen_q_, - // TODO: check if we need max_seqlen_k - std::optional max_seqlen_k_, - std::optional &page_table_, // (b_k, max_num_pages_per_seq) - std::optional &kv_batch_idx_, // b. indices to index into the KV cache - std::optional &leftpad_k_, // b - std::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) - std::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) - std::optional &seqlens_rotary_, // b - std::optional &q_descale_, // (b, h_k), not (b, h) - std::optional &k_descale_, // (b, h_k) - std::optional &v_descale_, // (b, h_k) - float const softmax_scale, - bool is_causal, - int window_size_left, - int window_size_right, - float const softcap, - bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 - std::optional &scheduler_metadata_, // (b + 1) - int num_splits, - std::optional pack_gqa_, - int const sm_margin, - std::optional &s_aux_ // (h) - ); - -std::vector mha_bwd( - const at::Tensor &dout, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor &k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k - const at::Tensor &v, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k - const at::Tensor &out, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor &softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q - std::optional &dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - std::optional &dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k - std::optional &dv_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k - std::optional &cu_seqlens_q_, // b+1 - std::optional &cu_seqlens_k_, // b+1 - std::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. - std::optional &seqused_k_, // b. If given, only this many elements of each batch element's keys are used. - std::optional max_seqlen_q_, - std::optional max_seqlen_k_, - float const softmax_scale, - bool is_causal, - int window_size_left, - int window_size_right, - float const softcap, - bool const deterministic, - int const sm_margin); - -std::vector -mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x seqlen x num_heads x head_size - const at::Tensor &lse_partial, // num_splits x batch_size x seqlen x num_heads - std::optional out_, // batch_size x seqlen x num_heads x head_size - std::optional out_dtype_ - ); - -at::Tensor -mha_fwd_get_scheduler_metadata( - int batch_size, - int max_seqlen_q, - int max_seqlen_k, - int num_heads, - int num_heads_k, - int headdim, - int headdim_v, - at::ScalarType qkv_dtype, - const at::Tensor &seqused_k, // b - std::optional &cu_seqlens_q_, // b+1 - std::optional &cu_seqlens_k_, // b+1 - std::optional &cu_seqlens_k_new_, // b+1 - std::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. - std::optional &leftpad_k_, // b - std::optional page_size, - int max_seqlen_k_new, // 0 means we're not appending new KV - bool is_causal, - int window_size_left, - int window_size_right, - bool has_softcap, - int num_splits, - std::optional pack_gqa_, - int const sm_margin - ); - diff --git a/torch-ext/vllm_flash_attn3/__init__.py b/torch-ext/vllm_flash_attn3/__init__.py deleted file mode 100644 index 3d12edd447e84d6e65c4fa9ab91ae3c3192884c1..0000000000000000000000000000000000000000 --- a/torch-ext/vllm_flash_attn3/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from .flash_attn_interface import ( - flash_attn_combine, - flash_attn_func, - flash_attn_qkvpacked_func, - flash_attn_varlen_func, - flash_attn_with_kvcache, - get_scheduler_metadata, -) - -__all__ = [ - "flash_attn_combine", - "flash_attn_func", - "flash_attn_qkvpacked_func", - "flash_attn_varlen_func", - "flash_attn_with_kvcache", - "get_scheduler_metadata", -] diff --git a/torch-ext/vllm_flash_attn3/flash_attn_interface.py b/torch-ext/vllm_flash_attn3/flash_attn_interface.py deleted file mode 100644 index 7f29690eb05c7502141e9c57ed3273185e93e285..0000000000000000000000000000000000000000 --- a/torch-ext/vllm_flash_attn3/flash_attn_interface.py +++ /dev/null @@ -1,815 +0,0 @@ -# Copyright (c) 2023, Tri Dao. - -from typing import Optional, Union - -import torch -import torch.nn as nn - -# isort: off -# We need to import the CUDA kernels after importing torch -from ._ops import ops - -# isort: on - - -def maybe_contiguous(x): - return x.contiguous() if x is not None and x.stride(-1) != 1 else x - - -def _flash_attn_forward( - q, - k, - v, - k_new, - v_new, - qv, - out, - cu_seqlens_q, - cu_seqlens_k, - cu_seqlens_k_new, - seqused_q, - seqused_k, - max_seqlen_q, - max_seqlen_k, - page_table, - kv_batch_idx, - leftpad_k, - rotary_cos, - rotary_sin, - seqlens_rotary, - q_descale, - k_descale, - v_descale, - softmax_scale, - causal, - window_size=(-1, -1), - softcap=0.0, - rotary_interleaved=True, - scheduler_metadata=None, - num_splits=1, - pack_gqa=None, - sm_margin=0, - s_aux=None): - q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)] - v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v - cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [ - maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new) - ] - seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)] - page_table, kv_batch_idx, leftpad_k = [ - maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k) - ] - rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] - seqlens_rotary = maybe_contiguous(seqlens_rotary) - out, softmax_lse, *rest = ops.fwd( - q, - k, - v, - k_new, - v_new, - qv, - out, - cu_seqlens_q, - cu_seqlens_k, - cu_seqlens_k_new, - seqused_q, - seqused_k, - max_seqlen_q, - max_seqlen_k, - page_table, - kv_batch_idx, - leftpad_k, - rotary_cos, - rotary_sin, - seqlens_rotary, - q_descale, - k_descale, - v_descale, - softmax_scale, - causal, - window_size[0], - window_size[1], - softcap, - rotary_interleaved, - scheduler_metadata, - num_splits, - pack_gqa, - sm_margin, - s_aux - ) - return out, softmax_lse, *rest - - -def _flash_attn_backward( - dout, - q, - k, - v, - out, - softmax_lse, - cu_seqlens_q, - cu_seqlens_k, - sequed_q, - sequed_k, - max_seqlen_q, - max_seqlen_k, - dq, - dk, - dv, - softmax_scale, - causal, - window_size=(-1, -1), - softcap=0.0, - deterministic=False, - sm_margin=0, -): - # dq, dk, dv are allocated by us so they should already be contiguous - dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - dq, dk, dv, softmax_d, *rest = ops.bwd( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - sequed_q, - sequed_k, - max_seqlen_q, - max_seqlen_k, - softmax_scale, - causal, - window_size[0], - window_size[1], - softcap, - deterministic, - sm_margin, - ) - return dq, dk, dv, softmax_d - - -class FlashAttnQKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - softmax_scale, - causal, - q_descale=None, k_descale=None, v_descale=None, - window_size=(-1, -1), - softcap=0.0, - deterministic=False, - num_heads_q=None, - ): - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - if qkv.dim() == 5: - assert qkv.shape[-3] == 3 - q, k, v = qkv.unbind(dim=-3) - else: - assert qkv.dim() == 4 - assert num_heads_q is not None - num_heads_k = (qkv.shape[2] - num_heads_q) // 2 - assert num_heads_k * 2 + num_heads_q == qkv.shape[2] - q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2) - out, softmax_lse, *rest = _flash_attn_forward( - q, - k, - v, - None, None, # k_new, v_new - None, # qv - None, # out - None, None, None, # cu_seqlens_q/k/k_new - None, None, # seqused_q/k - None, None, # max_seqlen_q/k - None, None, None, # page_table, kv_batch_idx, leftpad_k, - None, None, None, # rotary_cos/sin, seqlens_rotary - q_descale, k_descale, v_descale, - softmax_scale, - causal=causal, - window_size=window_size, - softcap=softcap, - ) - # ctx.save_for_backward(q, k, v, out_padded, softmax_lse) - ctx.save_for_backward(q, k, v, out, softmax_lse) - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.deterministic = deterministic - ctx.ndim = qkv.dim() - # return out, softmax_lse - return out - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse = ctx.saved_tensors - if ctx.ndim == 5: - qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) - dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) - dq, dk, dv = dqkv.unbind(dim=-3) - else: - num_heads_q = q.shape[2] - num_heads_k = k.shape[2] - qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:]) - dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) - dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2) - _flash_attn_backward( - dout, - q, - k, - v, - out, - softmax_lse, - None, None, # cu_seqlens_q, cu_seqlens_k, - None, None, # sequed_q, sequed_k, - None, None, # max_seqlen_q, max_seqlen_k, - dq, - dk, - dv, - ctx.softmax_scale, - ctx.causal, - ctx.window_size, - ctx.softcap, - ctx.deterministic, - ) - dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None, None - - -class FlashAttnFunc(torch.autograd.Function): - - @staticmethod - def forward( - ctx, - q, - k, - v, - softmax_scale, - causal, - qv=None, - q_descale=None, k_descale=None, v_descale=None, - window_size=(-1, -1), - softcap=0.0, - num_splits=1, - pack_gqa=None, - deterministic=False, - sm_margin=0, - s_aux=None, - ): - if softmax_scale is None: - softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) - # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward( - out, softmax_lse, *rest = _flash_attn_forward( - q, - k, - v, - None, None, # k_new, v_new - qv, # qv - None, # out - None, None, None, # cu_seqlens_q/k/k_new - None, None, # seqused_q/k - None, None, # max_seqlen_q/k - None, None, None, # page_table, kv_batch_idx, leftpad_k, - None, None, None, # rotary_cos/sin, seqlens_rotary - q_descale, k_descale, v_descale, - softmax_scale, - causal=causal, - window_size=window_size, - softcap=softcap, - num_splits=num_splits, - pack_gqa=pack_gqa, - sm_margin=sm_margin, - s_aux=s_aux, - ) - # ctx.save_for_backward(q, k, v, out_padded, softmax_lse) - ctx.save_for_backward(q, k, v, out, softmax_lse) - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.deterministic = deterministic - ctx.sm_margin = sm_margin - return out, softmax_lse - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse = ctx.saved_tensors - dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) - _flash_attn_backward( - dout, - q, - k, - v, - out, - softmax_lse, - None, None, # cu_seqlens_q, cu_seqlens_k, - None, None, # sequed_q, sequed_k, - None, None, # max_seqlen_q, max_seqlen_k, - dq, - dk, - dv, - ctx.softmax_scale, - ctx.causal, - ctx.window_size, - ctx.softcap, - ctx.deterministic, - ctx.sm_margin, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None - - -class FlashAttnVarlenFunc(torch.autograd.Function): - - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, - max_seqlen_q, - max_seqlen_k, - softmax_scale, - causal, - qv=None, - q_descale=None, k_descale=None, v_descale=None, - window_size=(-1, -1), - softcap=0.0, - num_splits=1, - pack_gqa=None, - deterministic=False, - sm_margin=0, - s_aux=None, - ): - if softmax_scale is None: - softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) - # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward( - out, softmax_lse, *rest = _flash_attn_forward( - q, - k, - v, - None, None, # k_new, v_new - qv, # qv - None, # out - cu_seqlens_q, - cu_seqlens_k, - None, # cu_seqlens_k_new - seqused_q, - seqused_k, - max_seqlen_q, - max_seqlen_k, - None, None, None, # page_table, kv_batch_idx, leftpad_k, - None, None, None, # rotary_cos/sin, seqlens_rotary - q_descale, k_descale, v_descale, - softmax_scale, - causal=causal, - window_size=window_size, - softcap=softcap, - num_splits=num_splits, - pack_gqa=pack_gqa, - sm_margin=sm_margin, - s_aux=s_aux, - ) - # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) - ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.deterministic = deterministic - ctx.sm_margin = sm_margin - return out, softmax_lse - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors - dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) - _flash_attn_backward( - dout, - q, - k, - v, - out, - softmax_lse, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, - ctx.max_seqlen_q, - ctx.max_seqlen_k, - dq, - dk, - dv, - ctx.softmax_scale, - ctx.causal, - ctx.window_size, - ctx.softcap, - ctx.deterministic, - ctx.sm_margin, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None - - -def flash_attn_qkvpacked_func( - qkv, - softmax_scale=None, - causal=False, - q_descale=None, k_descale=None, v_descale=None, - window_size=(-1, -1), - softcap=0.0, - deterministic=False, - num_heads_q=None, -): - """dropout_p should be set to 0.0 during evaluation - If Q, K, V are already stacked into 1 tensor, this function will be faster than - calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation - of the gradients of Q, K, V. - For multi-query and grouped-query attention (MQA/GQA), please see - flash_attn_kvpacked_func and flash_attn_func. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. - - Arguments: - qkv: (batch_size, seqlen, 3, nheads, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - softcap: float. Anything > 0 activates softcapping attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to - the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnQKVPackedFunc.apply( - qkv, - softmax_scale, - causal, - q_descale, k_descale, v_descale, - window_size, - softcap, - deterministic, - num_heads_q, - ) - - -def flash_attn_func( - q, - k, - v, - softmax_scale=None, - causal=False, - qv=None, - q_descale=None, k_descale=None, v_descale=None, - window_size=(-1, -1), - softcap=0.0, - num_splits=1, - pack_gqa=None, - deterministic=False, - sm_margin=0, - s_aux=None, -): - """dropout_p should be set to 0.0 during evaluation - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (batch_size, seqlen, nheads, headdim) - k: (batch_size, seqlen, nheads_k, headdim) - v: (batch_size, seqlen, nheads_k, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - """ - return FlashAttnFunc.apply( - q, - k, - v, - softmax_scale, - causal, - qv, - q_descale, k_descale, v_descale, - window_size, - softcap, - num_splits, - pack_gqa, - deterministic, - sm_margin, - s_aux, - ) - - -def flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - seqused_q=None, - seqused_k=None, - softmax_scale=None, - causal=False, - qv=None, - q_descale=None, k_descale=None, v_descale=None, - window_size=(-1, -1), - softcap=0.0, - num_splits=1, - pack_gqa=None, - deterministic=False, - sm_margin=0, - s_aux=None, -): - return FlashAttnVarlenFunc.apply( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, - max_seqlen_q, - max_seqlen_k, - softmax_scale, - causal, - qv, - q_descale, k_descale, v_descale, - window_size, - softcap, - num_splits, - pack_gqa, - deterministic, - sm_margin, - s_aux, - ) - - -def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None): - return ops.fwd_combine(out_partial, lse_partial, out, out_dtype) - - -def flash_attn_with_kvcache( - q, - k_cache, - v_cache, - k=None, - v=None, - qv=None, - rotary_cos=None, - rotary_sin=None, - cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, - cache_batch_idx: Optional[torch.Tensor] = None, - cache_leftpad: Optional[torch.Tensor] = None, - page_table: Optional[torch.Tensor] = None, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k_new: Optional[torch.Tensor] = None, - max_seqlen_q: Optional[int] = None, - rotary_seqlens: Optional[torch.Tensor] = None, - q_descale: Optional[torch.Tensor] = None, - k_descale: Optional[torch.Tensor] = None, - v_descale: Optional[torch.Tensor] = None, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - rotary_interleaved=True, - scheduler_metadata=None, - num_splits=0, # Can be tuned for speed - pack_gqa=None, # Can be tuned for speed - sm_margin=0, # Can be tuned if some SMs are used for communication - return_softmax_lse=False, - s_aux=None, -): - """ - If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from - k and v. This is useful for incremental decoding: you can pass in the cached keys/values from - the previous step, and update them with the new keys/values from the current step, and do - attention with the updated cache, all in 1 kernel. - - If you pass in k / v, you must make sure that the cache is large enough to hold the new values. - For example, the KV cache could be pre-allocated with the max sequence length, and you can use - cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. - - Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be - rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. - If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos - and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. - If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at - indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). - - See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. - - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Note: Does not support backward pass. - - Arguments: - q: (batch_size, seqlen, nheads, headdim) - k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table, - or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache) - page_block_size must be a multiple of 256. - v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table, - or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache) - k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate - k with k_cache, starting at the indices specified by cache_seqlens. - v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k. - qv [optional]: (batch_size, seqlen, nheads, headdim_v) - rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding - to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. - rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. - cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the - KV cache. - cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. - If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. - If the indices are not distinct, and k and v are provided, the values updated in the cache - might come from any of the duplicate indices. - cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0. - page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - softcap: float. Anything > 0 activates softcapping attention. - rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. - If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, - rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 - (i.e. GPT-NeoX style). - num_splits: int. If > 1, split the key/value into this many chunks along the sequence. - If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic - to automatically determine the number of splits. - Don't change this unless you know what you are doing. - return_softmax_lse: bool. Whether to return the logsumexp of the attention scores. - - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - """ - assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" - assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" - if softmax_scale is None: - softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) - if cache_seqlens is not None and isinstance(cache_seqlens, int): - cache_seqlens = torch.full( - (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device - ) - cache_seqlens = maybe_contiguous(cache_seqlens) - out, softmax_lse, *rest = _flash_attn_forward( - q, - k_cache, - v_cache, - k, - v, - qv, - None, # out - cu_seqlens_q, - None, # cu_seqlens_k - cu_seqlens_k_new, - None, # seqused_q - cache_seqlens, - max_seqlen_q, - None, # max_seqlen_k - page_table, - cache_batch_idx, - cache_leftpad, - rotary_cos, - rotary_sin, - rotary_seqlens, - q_descale, k_descale, v_descale, - softmax_scale, - causal=causal, - window_size=window_size, - softcap=softcap, - rotary_interleaved=rotary_interleaved, - scheduler_metadata=scheduler_metadata, - num_splits=num_splits, - pack_gqa=pack_gqa, - sm_margin=sm_margin, - s_aux=s_aux, - ) - # return (out, softmax_lse) if return_softmax_lse else out - return (out, softmax_lse, *rest) if return_softmax_lse else out - - -def get_scheduler_metadata( - batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, - cache_seqlens: torch.Tensor, - qkv_dtype=torch.bfloat16, - headdim_v=None, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k_new: Optional[torch.Tensor] = None, - cache_leftpad: Optional[torch.Tensor] = None, - page_size: Optional[int] = None, - max_seqlen_k_new=0, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - has_softcap=False, - num_splits=0, # Can be tuned for speed - pack_gqa=None, # Can be tuned for speed - sm_margin=0, # Can be tuned if some SMs are used for communication -): - cache_seqlens = maybe_contiguous(cache_seqlens) - if headdim_v is None: - headdim_v = headdim - scheduler_metadata = ops.get_scheduler_metadata( - batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v, - qkv_dtype, - cache_seqlens, - cu_seqlens_q, - None, # cu_seqlens_k - cu_seqlens_k_new, - None, # seqused_q - cache_leftpad, - page_size, - max_seqlen_k_new, - causal, - window_size[0], window_size[1], - has_softcap, - num_splits, - pack_gqa, - sm_margin, - ) - return scheduler_metadata