diff --git a/build.toml b/build.toml new file mode 100644 index 0000000000000000000000000000000000000000..3c6140427612081944ce9313e2b367dd20a9049b --- /dev/null +++ b/build.toml @@ -0,0 +1,593 @@ +[general] +name = "flash_attn3" +universal = false +cuda-minver = "12.4" +cuda-maxver = "12.4" + +[torch] +src = [ + "torch-ext/pytorch_shim.h", + "torch-ext/torch_binding.cpp", + "torch-ext/torch_binding.h", +] + +[kernel.flash_attn] +backend = "cuda" +cuda-capabilities = ["8.0", "9.0a"] +cuda-flags = [ + "-O3", + "-std=c++17", + "--ftemplate-backtrace-limit=0", # To debug template code + "--use_fast_math", + "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", + "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1", + "-DCUTLASS_ENABLE_GDC_FOR_SM90", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "-DNDEBUG", +] + +src = [ + "flash-attn/cuda_check.h", + "flash-attn/flash_api.cpp", + "flash-attn/flash_fwd_combine.cu", + "flash-attn/flash_fwd_combine_kernel.h", + "flash-attn/flash_fwd_combine_launch_template.h", + "flash-attn/flash.h", + "flash-attn/flash_prepare_scheduler.cu", + "flash-attn/heuristics.h", + "flash-attn/seqlen.h", + "flash-attn/static_switch.h", + "flash-attn/tile_size.h", + "flash-attn/utils.h", +] +depends = ["torch", "cutlass_3_9"] + +[kernel.flash_attn_sm80] +backend = "cuda" +cuda-capabilities = ["8.0", "9.0a"] +cuda-flags = [ + "-O3", + "-std=c++17", + "--ftemplate-backtrace-limit=0", # To debug template code + "--use_fast_math", + "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", + "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1", + "-DCUTLASS_ENABLE_GDC_FOR_SM90", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "-DNDEBUG", +] +src = [ + "flash-attn/block.h", + "flash-attn/copy_sm90_bulk_reduce.hpp", + "flash-attn/epilogue_bwd.hpp", + "flash-attn/epilogue_fwd.hpp", + "flash-attn/flash.h", + "flash-attn/flash_bwd_kernel_sm80.h", + "flash-attn/flash_bwd_kernel_sm90.h", + "flash-attn/flash_bwd_launch_template.h", + "flash-attn/flash_bwd_postprocess_kernel.h", + "flash-attn/flash_bwd_preprocess_kernel.h", + "flash-attn/flash_fwd_launch_template.h", + "flash-attn/flash_fwd_kernel_sm80.h", + "flash-attn/flash_fwd_kernel_sm90.h", + "flash-attn/heuristics.h", + "flash-attn/mainloop_bwd_sm80.hpp", + "flash-attn/mainloop_fwd_sm80.hpp", + "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp", + "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp", + "flash-attn/mask.h", + "flash-attn/named_barrier.hpp", + "flash-attn/pack_gqa.h", + "flash-attn/paged_kv.h", + "flash-attn/rotary.h", + "flash-attn/sm90_pipeline_no_cluster.hpp", + "flash-attn/softmax.h", + "flash-attn/tile_size.h", + "flash-attn/tile_scheduler.hpp", + + "flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu", + "flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu", + "flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu", + "flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu", + "flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu", + "flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu", + "flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu", + "flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu", + "flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu", + "flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu", + "flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu", + "flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm80.cu", + "flash-attn/instantiations/flash_bwd_hdim64_bf16_sm80.cu", + 
"flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm80.cu", + "flash-attn/instantiations/flash_bwd_hdim64_fp16_sm80.cu", + "flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm80.cu", + "flash-attn/instantiations/flash_bwd_hdim96_bf16_sm80.cu", + "flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm80.cu", + "flash-attn/instantiations/flash_bwd_hdim96_fp16_sm80.cu", + "flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim128_fp16_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim192_bf16_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim192_fp16_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim256_bf16_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu", 
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim256_fp16_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim64_bf16_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim64_fp16_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim96_bf16_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim96_fp16_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu", + "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu" +] +include = ["flash-attn"] +depends = ["torch", "cutlass_3_9"] + +[kernel.flash_attn_sm90] +backend = "cuda" +cuda-capabilities = ["8.0", "9.0a"] +cuda-flags = [ + "-O3", + "-std=c++17", + "--ftemplate-backtrace-limit=0", # To debug template code + "--use_fast_math", + "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", + "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1", + "-DCUTLASS_ENABLE_GDC_FOR_SM90", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "-DNDEBUG", +] +src = [ + "flash-attn/block.h", + "flash-attn/copy_sm90_bulk_reduce.hpp", + "flash-attn/epilogue_bwd.hpp", + "flash-attn/epilogue_fwd.hpp", + "flash-attn/flash.h", + "flash-attn/flash_bwd_kernel_sm80.h", + "flash-attn/flash_bwd_kernel_sm90.h", + "flash-attn/flash_bwd_launch_template.h", + "flash-attn/flash_bwd_postprocess_kernel.h", + "flash-attn/flash_bwd_preprocess_kernel.h", + "flash-attn/flash_fwd_launch_template.h", + 
"flash-attn/flash_fwd_kernel_sm80.h", + "flash-attn/flash_fwd_kernel_sm90.h", + "flash-attn/heuristics.h", + "flash-attn/mainloop_bwd_sm80.hpp", + "flash-attn/mainloop_fwd_sm80.hpp", + "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp", + "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp", + "flash-attn/mask.h", + "flash-attn/named_barrier.hpp", + "flash-attn/pack_gqa.h", + "flash-attn/paged_kv.h", + "flash-attn/rotary.h", + "flash-attn/sm90_pipeline_no_cluster.hpp", + "flash-attn/softmax.h", + "flash-attn/tile_size.h", + "flash-attn/tile_scheduler.hpp", + + "flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu", + "flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu", + "flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu", + "flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu", + "flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu", + "flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu", + "flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu", + "flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu", + "flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu", + "flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu", + "flash-attn/instantiations/flash_bwd_hdim256_fp16_sm90.cu", + "flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm90.cu", + "flash-attn/instantiations/flash_bwd_hdim64_bf16_sm90.cu", + "flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm90.cu", + "flash-attn/instantiations/flash_bwd_hdim64_fp16_sm90.cu", + "flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm90.cu", + "flash-attn/instantiations/flash_bwd_hdim96_bf16_sm90.cu", + "flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm90.cu", + "flash-attn/instantiations/flash_bwd_hdim96_fp16_sm90.cu", + "flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_e4m3_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu", + 
"flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_fp16_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_bf16_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu", + 
"flash-attn/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_e4m3_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_fp16_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_bf16_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_e4m3_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_fp16_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu", + 
"flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_bf16_sm90.cu", + 
"flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_e4m3_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_fp16_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_bf16_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_e4m3_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu", + 
"flash-attn/instantiations/flash_fwd_hdim96_fp16_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_bf16_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_e4m3_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_fp16_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu", + 
"flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu", + "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu", +] +include = ["flash-attn"] +depends = ["torch", "cutlass_3_9"] + +# [kernel.flash_attn_sm100] +# backend = "cuda" +# cuda-capabilities = ["8.0", "9.0a", "10.0"] +# cuda-flags = [ +# "-O3", +# "-std=c++17", +# "--ftemplate-backtrace-limit=0", # To debug template code +# "--use_fast_math", +# "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", +# "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1", +# "-DCUTLASS_ENABLE_GDC_FOR_SM90", +# "--expt-relaxed-constexpr", +# "--expt-extended-lambda", +# "--use_fast_math", +# "-DNDEBUG", +# ] +# src = [ +# "flash-attn/block.h", +# "flash-attn/copy_sm90_bulk_reduce.hpp", +# "flash-attn/epilogue_bwd.hpp", +# "flash-attn/epilogue_fwd.hpp", +# "flash-attn/flash.h", +# "flash-attn/flash_bwd_kernel_sm80.h", +# "flash-attn/flash_bwd_kernel_sm90.h", +# "flash-attn/flash_bwd_launch_template.h", +# "flash-attn/flash_bwd_postprocess_kernel.h", +# "flash-attn/flash_bwd_preprocess_kernel.h", +# "flash-attn/flash_fwd_launch_template.h", +# "flash-attn/flash_fwd_kernel_sm80.h", +# "flash-attn/flash_fwd_kernel_sm90.h", +# "flash-attn/heuristics.h", +# "flash-attn/mainloop_bwd_sm80.hpp", +# "flash-attn/mainloop_fwd_sm80.hpp", +# "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp", +# "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp", +# "flash-attn/mask.h", +# "flash-attn/named_barrier.hpp", +# "flash-attn/pack_gqa.h", +# "flash-attn/paged_kv.h", +# "flash-attn/rotary.h", +# "flash-attn/sm90_pipeline_no_cluster.hpp", +# "flash-attn/softmax.h", +# "flash-attn/tile_size.h", +# "flash-attn/tile_scheduler.hpp", +# +# "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm100.cu", +# ] +# include = ["flash-attn"] +# depends = ["torch", "cutlass_3_9"] diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000000000000000000000000000000000000..a5932d78d8e0ddd3bf0173322caacab55f3790d8 --- /dev/null +++ b/flake.lock @@ -0,0 +1,168 @@ +{ + "nodes": { + "flake-compat": { + "locked": { + "lastModified": 1747046372, + "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-compat_2": { + "locked": { + "lastModified": 1733328505, + "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=", + "owner": 
"edolstra", + "repo": "flake-compat", + "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_2": { + "inputs": { + "systems": "systems_2" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "hf-nix": { + "inputs": { + "flake-compat": "flake-compat_2", + "flake-utils": "flake-utils_2", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1750234878, + "narHash": "sha256-q9DRC9zdpzUf88qqg1qbhP1qgJbE2cMtn8oUmosuyT8=", + "owner": "huggingface", + "repo": "hf-nix", + "rev": "c7132f90763d756da3e77da62e01be0a4546dc57", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "hf-nix", + "type": "github" + } + }, + "kernel-builder": { + "inputs": { + "flake-compat": "flake-compat", + "flake-utils": "flake-utils", + "hf-nix": "hf-nix", + "nixpkgs": [ + "kernel-builder", + "hf-nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1750275112, + "narHash": "sha256-gqAxmLLt0tYvuRYumOZHQgryMeEFdt6j3nEC8B5rT14=", + "owner": "huggingface", + "repo": "kernel-builder", + "rev": "1b63210b2a1fc3cda2e3a579e7aa8f8c8532626f", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "kernel-builder", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1747820358, + "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=", + "owner": "danieldk", + "repo": "nixpkgs", + "rev": "d3c1681180717528068082103bf323147de6ab0b", + "type": "github" + }, + "original": { + "owner": "danieldk", + "ref": "cudatoolkit-12.9-kernel-builder", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "kernel-builder": "kernel-builder" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000000000000000000000000000000000000..b895427963bea8c91d8054b4478074bc8ba838e9 --- /dev/null +++ b/flake.nix @@ -0,0 +1,17 @@ +{ + description = "Flake for Hopper Flash Attention kernel"; + + inputs = { + kernel-builder.url = "github:huggingface/kernel-builder"; + }; + + outputs = + { + self, + kernel-builder, + }: + 
kernel-builder.lib.genFlakeOutputs { + path = ./.; + rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate; + }; +} diff --git a/flash-attn/block.h b/flash-attn/block.h new file mode 100644 index 0000000000000000000000000000000000000000..e69eede49ad4f544ebfbddeb3fd908230bade7c2 --- /dev/null +++ b/flash-attn/block.h @@ -0,0 +1,139 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +namespace flash { + +template +struct BlockMN { + + static + CUTLASS_DEVICE + cute::tuple get_n_block_min_max( + SeqlenInfo_t const& seqlen_info, + int const m_block, int const bidb, int const split_idx, int const num_splits, + int const window_size_left, int const window_size_right, + cutlass::FastDivmod const& attention_chunk_divmod, + cutlass::FastDivmod const& qhead_per_khead_divmod) { + + int const seqlen_k = seqlen_info.seqlen_k; + int const seqlen_q = seqlen_info.seqlen_q; + int n_block_max = cute::ceil_div(seqlen_k, kBlockN); + if constexpr (Is_causal || Is_local) { + int m_idx_max = (m_block + 1) * kBlockM; + // TODO: check off-by-1 error + if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; } + int const n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q; + int n_idx_right = !Is_local ? n_idx : n_idx + window_size_right; + if (Is_local && attention_chunk_divmod.divisor > 0) { + n_idx_right = std::min(n_idx_right, flash::round_up(attention_chunk_divmod, n_idx)); + } + n_block_max = std::min(n_block_max, cute::ceil_div(n_idx_right, kBlockN)); + } + int n_block_min = 0; + if constexpr (Is_local) { + int m_idx_min = m_block * kBlockM; + if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); } + int const n_idx = m_idx_min + seqlen_k - seqlen_q; + int n_idx_left = n_idx - window_size_left; + if (attention_chunk_divmod.divisor > 0) { + n_idx_left = std::max(n_idx_left, flash::round_down(attention_chunk_divmod, n_idx)); + } + n_block_min = std::max(int(0), n_idx_left / kBlockN); + } + // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } + if constexpr (Split) { + uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits + int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u); + int split_idx_actual = split_idx & 0x0000FFFF; + int num_splits_actual = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; + int num_n_blocks_per_split = n_block_max <= n_block_min ? 
0 : cute::ceil_div(n_block_max - n_block_min, num_splits_actual); + n_block_min = n_block_min + split_idx_actual * num_n_blocks_per_split; + n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max); + // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, num_splits_dynamic = %d, num_splits_actual = %d, num_n_blocks_per_split = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, num_splits_dynamic, num_splits_actual, num_n_blocks_per_split, n_block_min, n_block_max); } + } + // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } + return {n_block_min, n_block_max}; + } + + static + CUTLASS_DEVICE + cute::tuple get_n_block_k_new_min_max( + SeqlenInfo_t const& seqlen_info, + int const m_block, int const bidb, int const split_idx, int const num_splits, + int const window_size_left, int const window_size_right, + cutlass::FastDivmod const& attention_chunk_divmod, + cutlass::FastDivmod const& qhead_per_khead_divmod) { + + auto [n_block_min, n_block_max] = get_n_block_min_max( + seqlen_info, m_block, bidb, split_idx, num_splits, + window_size_left, window_size_right, attention_chunk_divmod, qhead_per_khead_divmod); + int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0); + int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new); + int const n_block_new_min = idx_k_new_min / kBlockN; + int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min; + // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);} + return {n_block_new_min, n_block_new_max}; + } + + static + CUTLASS_DEVICE + cute::tuple get_m_block_min_max( + SeqlenInfo_t const& seqlen_info, + int const n_block, int const bidb, + int const window_size_left, int const window_size_right, int const sink_token_length) { + // TODO: support attention_chunk + int const seqlen_q = seqlen_info.seqlen_q; + int const seqlen_k = seqlen_info.seqlen_k; + int m_block_max = cute::ceil_div(seqlen_q, kBlockM); + if constexpr (Is_local) { + if (n_block >= cute::ceil_div(sink_token_length, kBlockN)) { + m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_size_left, kBlockM)); + } + } + int m_block_min = 0; + if constexpr (Is_causal || Is_local) { + m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - window_size_right) / kBlockM); + } + return {m_block_min, m_block_max}; + } + + // If we have separate iterations with causal or local masking at the start, where do we stop + static + CUTLASS_DEVICE + int get_n_block_min_causal_local_mask( + SeqlenInfo_t const& seqlen_info, + int const m_block, int const n_block_min, int const window_size_right, + cutlass::FastDivmod const& attention_chunk_divmod, + cutlass::FastDivmod const& qhead_per_khead_divmod) { + int const m_idx_min = !PackGQA ? 
m_block * kBlockM : qhead_per_khead_divmod.divide(m_block * kBlockM); + int const n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q; + int n_idx_right = !Is_local ? n_idx : n_idx + window_size_right; + if (Is_local && attention_chunk_divmod.divisor > 0) { + n_idx_right = std::min(n_idx_right, flash::round_up(attention_chunk_divmod, n_idx)); + } + return std::max(n_block_min, n_idx_right / kBlockN); + } + + // If we have separate iterations with local masking at the end, where do we stop the non-masked iterations + static + CUTLASS_DEVICE + int get_n_block_min_before_local_mask( + SeqlenInfo_t const& seqlen_info, + int const m_block, int const n_block_min, int const window_size_left, + cutlass::FastDivmod const& attention_chunk_divmod, + cutlass::FastDivmod const& qhead_per_khead_divmod) { + int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1; + int const n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q; + int n_idx_left = !Is_local ? n_idx : n_idx - window_size_left; + if (Is_local && attention_chunk_divmod.divisor > 0) { + n_idx_left = std::max(n_idx_left, flash::round_down(attention_chunk_divmod, n_idx)); + } + return !Is_local ? n_block_min : std::max(n_block_min, cute::ceil_div(n_idx_left, kBlockN)); + } + +}; + +} // namespace flash diff --git a/flash-attn/copy_sm90_bulk_reduce.hpp b/flash-attn/copy_sm90_bulk_reduce.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8556fae66d6d05af6c82bdc49735fdfece2a978b --- /dev/null +++ b/flash-attn/copy_sm90_bulk_reduce.hpp @@ -0,0 +1,49 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include + +namespace cute +{ + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_BULK_REDUCE_ADD +{ + CUTE_HOST_DEVICE static void + copy(float const* smem_ptr, + float * gmem_ptr, int32_t store_bytes) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n" + : + : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + CUTE_HOST_DEVICE static void + copy(float const* smem_ptr, + float * gmem_ptr, int32_t store_bytes, uint64_t cache_hint) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [%0], [%1], %2, %3;\n" + : + : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/flash-attn/cuda_check.h b/flash-attn/cuda_check.h new file mode 100644 index 0000000000000000000000000000000000000000..b5e63aef79d22f9afdf83da05dbad0f2b8397ac9 --- /dev/null +++ b/flash-attn/cuda_check.h @@ -0,0 +1,19 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while(0) + +#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) diff --git a/flash-attn/epilogue_bwd.hpp b/flash-attn/epilogue_bwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6d9b5f4f596813a6c3e8561a96c9e4acfd904794 --- /dev/null +++ b/flash-attn/epilogue_bwd.hpp @@ -0,0 +1,533 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/barrier.h" +#include "cute/tensor.hpp" + +#include "cutlass/gemm/collective/builders/sm90_common.inl" + +#include "seqlen.h" +#include "named_barrier.hpp" +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +struct CollectiveEpilogueBwd { + + using TileShape_MNK = TileShape_MNK_; + using Element = Element_; + using ArchTag = ArchTag_; + static constexpr int NumEpilogueThreads = NumEpilogueThreads_; + static constexpr bool Varlen = Varlen_; + static constexpr bool dKV_swapAB = dKV_swapAB_; + static constexpr bool Use_TMA = !Varlen && ArchTag::kMinComputeCapability >= 90; + + static_assert(ArchTag::kMinComputeCapability >= 80); + + using GmemTiledCopydKVTMA = cute::SM90_TMA_STORE; + + // These are for storing the output tensor without TMA (e.g., for setting output to zero) + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(get<2>(TileShape_MNK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, NumEpilogueThreads); + static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + using GmemTiledCopydKV = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per store + + using SmemLayoutAtomdKVTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), Int(TileShape_MNK{})) / AtomLayoutKdKV>>()); + using SmemLayoutdKVTMA = decltype(tile_to_shape(SmemLayoutAtomdKVTMA{}, select<1, 2>(TileShape_MNK{}))); + using SmemLayoutdKVtTMA = + decltype(cute::composition(SmemLayoutdKVTMA{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), + make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{})))); + + // If we don't use TMA + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : (kHeadDim % 32 == 0 ? 32 : 16); + static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1); + using SmemLayoutAtomdKVSTG = + decltype(composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + + using SmemLayoutAtomdKV = std::conditional_t; + using SmemLayoutdKV = decltype(tile_to_shape(SmemLayoutAtomdKV{}, select<1, 2>(TileShape_MNK{}))); + using SmemLayoutdKVt = + decltype(cute::composition(SmemLayoutdKV{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), + make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{})))); + + using SmemCopyAtomdKV = Copy_Atom< + std::conditional_t< + ArchTag::kMinComputeCapability >= 90, + std::conditional_t, + AutoVectorizingCopyWithAssumedAlignment<128> + >, + Element>; + + static constexpr size_t SmemAlignmentdKV = ArchTag::kMinComputeCapability >= 90 ? 
cutlass::detail::alignment_for_swizzle(SmemLayoutdKV{}) : 128; + static_assert(SmemAlignmentdKV >= 128, "Require at least 128B alignment"); + + struct TensorStorage : cute::aligned_struct { + cute::array_aligned, SmemAlignmentdKV> smem_dk; + cute::array_aligned, SmemAlignmentdKV> smem_dv; + }; + + using ShapedKV = cute::Shape; // (seqlen_k, d, head, batch) + using StridedKV = cute::Stride; + + using TMA_dKV = std::conditional_t< + Use_TMA, + decltype(make_tma_copy( + GmemTiledCopydKVTMA{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapedKV{}, StridedKV{}), + SmemLayoutdKVTMA{}, + select<1, 2>(TileShape_MNK{}), + _1{})), // no mcast for dKV + std::nullptr_t + >; + + // Host side kernel arguments + struct Arguments { + Element* ptr_dK; + ShapedKV const shape_dK; + StridedKV const stride_dK; + Element* ptr_dV; + ShapedKV const shape_dV; + StridedKV const stride_dV; + int const num_heads_q; + int* dk_semaphore; + int* dv_semaphore; + int const* cu_seqlens; + int const* seqused; + }; + + // Device side kernel params + struct Params { + Element* ptr_dK; + ShapedKV const shape_dK; + StridedKV const stride_dK; + Element* ptr_dV; + ShapedKV const shape_dV; + StridedKV const stride_dV; + TMA_dKV tma_store_dK, tma_store_dV; + int const* cu_seqlens = nullptr; + int const* seqused = nullptr; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK); + Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dV, args.stride_dV); + TMA_dKV tma_store_dK = [&] { + if constexpr (Use_TMA) { + return make_tma_copy(GmemTiledCopydKVTMA{}, mdK, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV + } else { + return nullptr; + } + }(); + TMA_dKV tma_store_dV = [&] { + if constexpr (Use_TMA) { + return make_tma_copy(GmemTiledCopydKVTMA{}, mdV, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV + } else { + return nullptr; + } + }(); + return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.shape_dV, args.stride_dV, + tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + if constexpr (Use_TMA) { + cute::prefetch_tma_descriptor(params.tma_store_dK.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_store_dV.get_tma_descriptor()); + } + } + + template + CUTLASS_DEVICE void + store(Params const& params, + FrgTensorO const& tdKrdK, + FrgTensorO const& tdVrdV, + SharedStorage& shared_storage, + TiledMma tiled_mma, + int thread_idx, + cute::tuple const& block_coord + ) { + + auto [n_block, bidh, bidb] = block_coord; + Tensor sdK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKV{})); + Tensor sdV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKV{})); + Tensor sdKt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKVt{})); + Tensor sdVt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKVt{})); + auto smem_tiled_copy_dKV = make_tiled_copy_C(SmemCopyAtomdKV{}, tiled_mma); + auto smem_thr_copy_dKV = 
smem_tiled_copy_dKV.get_thread_slice(thread_idx); + + Tensor tdVrdV_out = make_tensor_like(tdVrdV); + flash::convert_type_out(tdVrdV, tdVrdV_out); + Tensor tdKrdK_out = make_tensor_like(tdKrdK); + flash::convert_type_out(tdKrdK, tdKrdK_out); + Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N) + // if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_dKV); print(sdK); printf("\n"); print(sdKt); printf("\n"); } + Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(cute::conditional_return(sdK, sdKt)); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(cute::conditional_return(sdV, sdVt)); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // Make sure all WGs have finished reading K and V + flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV); + cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK); + if constexpr (Use_TMA) { + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + + Tensor mdK = params.tma_store_dK.get_tma_tensor(params.shape_dK); + Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dV); + Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) + Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) + auto block_tma_dK = params.tma_store_dK.get_slice(_0{}); + auto block_tma_dV = params.tma_store_dV.get_slice(_0{}); + Tensor tdKgdK = block_tma_dK.partition_D(gdK); // (TMA, TMA_M, TMA_K) + Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_M, TMA_K) + Tensor tdVgdV = block_tma_dV.partition_D(gdV); // (TMA, TMA_M, TMA_K) + Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_M, TMA_K) + int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0); + if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) { + cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + if (cute::elect_one_sync()) { + cute::copy(params.tma_store_dV, tdVsdV, tdVgdV); + cute::copy(params.tma_store_dK, tdKsdK, tdKgdK); + tma_store_arrive(); + } + } + tma_store_wait<0>(); + // // Tell warp 0 that smem_k and smem_v are ready + // cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); + + } else { + flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused}; + bool const is_varlen = Varlen && params.cu_seqlens; + Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0); + Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) + Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? 
bidb : 0); + Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) + + GmemTiledCopydKV gmem_tiled_copy_dKV; + auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx); + Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV); + Tensor tdKVsdV = gmem_thr_copy_dKV.partition_S(sdV); // (TMA, TMA_M, TMA_K) + Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK); + Tensor tdKVsdK = gmem_thr_copy_dKV.partition_S(sdK); // (TMA, TMA_M, TMA_K) + Tensor tdKVrdV = make_fragment_like(tdKVgdV); + Tensor tdKVrdK = make_fragment_like(tdKVgdK); + Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Repeat the partitioning with identity layouts + Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); + Tensor tdKVpdV = make_tensor(make_shape(size<2>(tdKVgdV))); + Tensor tdKVpdK = make_tensor(make_shape(size<2>(tdKVgdK))); + #pragma unroll + for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dV); } + #pragma unroll + for (int k = 0; k < size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); } + // Need to check OOB when reading from smem if kBlockN isn't evenly tiled + static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0; + flash::copy( + gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV, tdKVcdKV, tdKVpdV, kBlockN); + flash::copy( + gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK, tdKVcdKV, tdKVpdK, kBlockN); + // // Tell warp 0 that smem_k and smem_v are ready + // cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_k/v + // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); + // Construct identity layout for gdKV + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) + ); + flash::copy( + gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdK, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) + ); + } + } + + CUTLASS_DEVICE void + store_tail() { + // if constexpr (Use_TMA) { tma_store_wait<0>(); } + } + + // Write 0 to dK and dV + CUTLASS_DEVICE void + store_zero( + Params const& params, + int thread_idx, + cute::tuple const& block_coord + ) { + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + auto [n_block, bidh, bidb] = block_coord; + flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused}; + bool const is_varlen = Varlen && params.cu_seqlens; + Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0); + Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) + Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? 
bidb : 0); + Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) + + GmemTiledCopydKV gmem_tiled_copy_dKV; + auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx); + Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK); + Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV); + Tensor tdKVrdKV = make_fragment_like(tdKVgdK); + clear(tdKVrdKV); + // Construct identity layout for gdKV + Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); + Tensor tdKVpdK = make_tensor(make_shape(size<2>(tdKVgdK))); + Tensor tdKVpdV = make_tensor(make_shape(size<2>(tdKVgdV))); + #pragma unroll + for (int k = 0; k < size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); } + #pragma unroll + for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dV); } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdK, seqlen_info.seqlen - n_block * kBlockN + ); + flash::copy( + gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdV, seqlen_info.seqlen - n_block * kBlockN + ); + } + +}; + +template +struct CollectiveEpilogueBwdGQA { + + using TileShape_MNK = TileShape_MNK_; + using Element = ElementAccum; + using ArchTag = ArchTag_; + static constexpr int NumEpilogueThreads = NumEpilogueThreads_; + static constexpr bool Varlen = Varlen_; + static constexpr bool Use_TMA = ArchTag::kMinComputeCapability >= 90; + + static_assert(ArchTag::kMinComputeCapability >= 80); + + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + static_assert(NumEpilogueThreads % cutlass::NumThreadsPerWarp == 0, "NumEpilogueThreads must be a multiple of NumThreadsPerWarp"); + static constexpr int NumWarpGroups = NumEpilogueThreads / cutlass::NumThreadsPerWarpGroup; + // Thread layout, 256 or 384 threads per row + // We split into NumWarpGroups so that we can use the same postprocessing kernel as dQ + using R2SLayoutAtomdKVaccum = Layout, Int>>; + using R2STiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2SLayoutAtomdKVaccum{}, + Layout>{})); // Val layout, 4 vals per store + // For Sm80 + using R2GLayoutAtomdKVaccum = Layout>>; + using R2GTiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2GLayoutAtomdKVaccum{}, + Layout>{})); // Val layout, 1 vals per store + + using SmemLayoutdKVaccum = Layout, Int>>; + using SmemLayoutdKVaccumFlat = Layout>>; + + // Strangely without this SmemAlignment, the total smem for hdim 128 (80 x 128) is 228KB even though we + // only need 227KB. We use the same alignment as the non-GQA epilogue to avoid this issue. + static constexpr int SmemAlignment = kHeadDim % 64 == 0 ? 1024 : (kHeadDim % 32 == 0 ? 
512 : 256); + struct TensorStorageTMA : cute::aligned_struct { + cute::array_aligned, SmemAlignment> smem_dkv; + }; + struct TensorStorageSTG { + cute::array smem_dkv; + }; + using TensorStorage = std::conditional_t; + + using ShapedKV = cute::Shape; // (seqlen_k_rounded * d, head, batch) + using StridedKV = cute::Stride<_1, int64_t, int64_t>; + + // Host side kernel arguments + struct Arguments { + ElementAccum* ptr_dKaccum; + ShapedKV const shape_dKaccum; + StridedKV const stride_dKaccum; + ElementAccum* ptr_dVaccum; + ShapedKV const shape_dVaccum; + StridedKV const stride_dVaccum; + int num_heads_q; + int* dk_semaphore; + int* dv_semaphore; + int const* cu_seqlens; + int const* seqused; + }; + + // Device side kernel params + struct Params { + ElementAccum* ptr_dKaccum; + ShapedKV const shape_dKaccum; + StridedKV const stride_dKaccum; + ElementAccum* ptr_dVaccum; + ShapedKV const shape_dVaccum; + StridedKV const stride_dVaccum; + cutlass::FastDivmod qhead_per_khead_divmod; + int* dk_semaphore; + int* dv_semaphore; + int const* cu_seqlens = nullptr; + int const* seqused = nullptr; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + if constexpr (Deterministic) { + assert(args.dk_semaphore != nullptr); + assert(args.dv_semaphore != nullptr); + } + return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.shape_dVaccum, args.stride_dVaccum, + cutlass::FastDivmod(cute::ceil_div(args.num_heads_q, get<1>(args.shape_dKaccum))), + args.dk_semaphore, args.dv_semaphore, + args.cu_seqlens, args.seqused}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + } + + template + CUTLASS_DEVICE void + store(Params const& params, + FrgTensorO const& tdKrdK, + FrgTensorO const& tdVrdV, + SharedStorage& shared_storage, + TiledMma tiled_mma, + int thread_idx, + cute::tuple const& block_coord + ) { + + auto [n_block, bidh, bidb] = block_coord; + int bidh_idx_in_group; + int bidh_kv = params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh); + Tensor sdKV = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccum{}); + Tensor sdKV_flat = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccumFlat{}); + static constexpr int dKV_TMA_num_bytes = CUTE_STATIC_V(size(sdKV_flat)) * sizeof(ElementAccum); + + flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_dKaccum), params.cu_seqlens, params.seqused}; + bool const is_varlen = Varlen && params.cu_seqlens; + Tensor mdKaccum = make_tensor(make_gmem_ptr(params.ptr_dKaccum), params.shape_dKaccum, params.stride_dKaccum)(_, bidh_kv, !is_varlen ? bidb : 0); + Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dVaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? 
bidb : 0); + Tensor gdKaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdKaccum), Shape>{}, make_coord(n_block)); // (M * K) + Tensor gdVaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdVaccum), Shape>{}, make_coord(n_block)); // (M * K) + + R2STiledCopydKVaccum r2s_tiled_copy_dKVaccum; + auto r2s_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_thread_slice(thread_idx); + Tensor tdKVsdKVaccum = r2s_thr_copy_dKVaccum.partition_D(sdKV); + + // Only used if !Use_TMA + R2GTiledCopydKVaccum r2g_tiled_copy_dKVaccum; + auto r2g_thr_copy_dKVaccum = r2g_tiled_copy_dKVaccum.get_thread_slice(thread_idx); + + // Make sure all WGs have finished reading K and V, otherwise we get racy dQ + // because smem_q could be changed. + flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + if constexpr (Use_TMA) { + Tensor taccdKVrdV = r2s_thr_copy_dKVaccum.retile_S(tdVrdV); // ((Atom,AtomNum), MMA_M, MMA_N) + cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdV, tdKVsdKVaccum); + } + + // int const num_batch = params.num_batch; + int const num_batch = get<2>(params.shape_dKaccum); + int const num_head_kv = get<1>(params.shape_dKaccum); + int *lock_ptr = !Deterministic ? nullptr : params.dv_semaphore + bidb * num_head_kv + bidh_kv; + using Barrier = cutlass::GenericBarrier; + + // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);} + + if constexpr (Deterministic) { + Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group); + } + // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore);} + if constexpr (Use_TMA) { + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + if (thread_idx == 0) { + SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdVaccum.data()), dKV_TMA_num_bytes, static_cast(TMA::CacheHintSm90::EVICT_LAST)); + tma_store_arrive(); + tma_store_wait<0>(); + } + } else { + Tensor tdVrdV_atomic = r2g_thr_copy_dKVaccum.retile_S(tdVrdV); + Tensor tdVgdV_atomic = r2g_thr_copy_dKVaccum.partition_D(gdVaccum); + static_assert(CUTE_STATIC_V(size(tdVrdV_atomic)) == CUTE_STATIC_V(size(tdVgdV_atomic))); + #pragma unroll + for (int i = 0; i < size(tdVrdV_atomic); ++i) { atomicAdd(&tdVgdV_atomic(i), tdVrdV_atomic(i)); } + } + if constexpr (Deterministic) { + Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv); + } + + if constexpr (Use_TMA) { + cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + Tensor taccdKVrdK = r2s_thr_copy_dKVaccum.retile_S(tdKrdK); // ((Atom,AtomNum), MMA_M, MMA_N) + cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdK, tdKVsdKVaccum); + } + lock_ptr = !Deterministic ? 
nullptr : params.dk_semaphore + bidb * num_head_kv + bidh_kv; + // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);} + + if constexpr (Deterministic) { + Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group); + } + // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore);} + if constexpr (Use_TMA) { + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + if (thread_idx == 0) { + SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdKaccum.data()), dKV_TMA_num_bytes, static_cast(TMA::CacheHintSm90::EVICT_LAST)); + tma_store_arrive(); + tma_store_wait<0>(); + } + } else { + Tensor tdKrdK_atomic = r2g_thr_copy_dKVaccum.retile_S(tdKrdK); + Tensor tdKgdK_atomic = r2g_thr_copy_dKVaccum.partition_D(gdKaccum); + static_assert(CUTE_STATIC_V(size(tdKrdK_atomic)) == CUTE_STATIC_V(size(tdKgdK_atomic))); + #pragma unroll + for (int i = 0; i < size(tdKrdK_atomic); ++i) { atomicAdd(&tdKgdK_atomic(i), tdKrdK_atomic(i)); } + } + if constexpr (Deterministic) { + Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv); + } + // // Tell warp 0 that smem_k and smem_v are ready + // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); + } + + CUTLASS_DEVICE void + store_tail() { + } + + // Write 0 to dK and dV + CUTLASS_DEVICE void + store_zero( + Params const& params, + int thread_idx, + cute::tuple const& block_coord + ) { + // Don't need to do anything since dKaccum and dVaccum are already zero-initialized + } + +}; + +} // namespace flash diff --git a/flash-attn/epilogue_fwd.hpp b/flash-attn/epilogue_fwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..69102e8c4e6e9fc144fd0fcf3d624890781afb84 --- /dev/null +++ b/flash-attn/epilogue_fwd.hpp @@ -0,0 +1,484 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include // For FastDivMod +#include "cute/tensor.hpp" + +#include "cutlass/gemm/collective/builders/sm90_common.inl" +#include "cutlass/epilogue/collective/builders/sm90_common.inl" + +#include "seqlen.h" +#include "named_barrier.hpp" +#include "pack_gqa.h" +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +struct CollectiveEpilogueFwd { + + using TileShape_MNK_PV = TileShape_MNK_PV_; + using ClusterShape = ClusterShape_; + using Element = Element_; + using ElementPartial = float; + using ArchTag = ArchTag_; + static constexpr int NumEpilogueThreads = NumEpilogueThreads_; + static constexpr bool Varlen = Varlen_; + static constexpr bool PackGQA = PackGQA_; + static constexpr bool Split = Split_; + static constexpr bool Use_smem = !(Split && !Varlen); + static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA; + + static_assert(ArchTag::kMinComputeCapability >= 80); + static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1); + static_assert(sizeof(Element) <= 2); + + static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); + static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{}); + + static constexpr bool LargeHeadDimV = kHeadDimV > 256; + + using GmemTiledCopyOTMA = cute::SM90_TMA_STORE; + + // These are for storing the output tensor without TMA (e.g., for setting output to zero) + static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDimV % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore"); + // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements + // in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times + // we need to call divmod. + static constexpr int kBytePerRow = kHeadDimV * sizeof(Element); + static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore; + // If PackGQA, we split the work of compute O_ptr among threads in the same row, so we need this to within a warp + static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0); + static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumEpilogueThreads / kGmemThreadsPerRow"); + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per store + + using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_PV{})), decltype(cute::get<1>(TileShape_MNK_PV{}))>()); + using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 1>(TileShape_MNK_PV{}))); + static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1)); + static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 
3 : 4); + using SmemLayoutAtomO = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_MNK_PV{}))); + using SmemLayoutO = std::conditional_t= 90, SmemLayoutOTMA, SmemLayoutOSTS>; + + using ShapeO = cute::Shape; // (seqlen_q, d, head, batch, num_splits) + using StrideO = cute::Stride; + using StrideLSE = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen_q, head, batch, num_splits) + // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits) + using ShapeOPacked = std::conditional_t, int32_t, int32_t, int32_t, int32_t>>; + using StrideOPacked = std::conditional_t, _1, int64_t, int64_t, int64_t>>; + // ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits) + using ShapeLSEPacked = std::conditional_t, cute::Shape, int32_t, int32_t, int32_t>>; + using StrideLSEPacked = std::conditional_t, int64_t, int64_t, int64_t>>; + + using CopyOpR2S = std::conditional_t< + ArchTag::kMinComputeCapability >= 90, + // cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16) + decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator()), + AutoVectorizingCopyWithAssumedAlignment<128> + >; + using SmemCopyAtomO = Copy_Atom; + + // static constexpr size_t SmemAlignmentO = cutlass::detail::alignment_for_swizzle(SmemLayoutO{}); + // static_assert(SmemAlignmentO >= 128, "Require at least 128B alignment"); + // struct TensorStorage : cute::aligned_struct { + // cute::array_aligned : 0, SmemAlignmentO> smem_o; + // }; + struct TensorStorage : cute::aligned_struct<128> { + cute::array_aligned : 0> smem_o; + }; + + using TMA_O = std::conditional_t< + Use_TMA_O, + decltype(make_tma_copy( + GmemTiledCopyOTMA{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeO{}, StrideO{}), + SmemLayoutOTMA{}, + select<0, 1>(TileShape_MNK_PV{}), + _1{})), // no mcast for O + std::nullptr_t + >; + + // Host side kernel arguments + struct Arguments { + Element* ptr_O; + ShapeO const shape_O; + StrideO const stride_O; + ElementPartial* ptr_O_partial; + StrideO const stride_O_partial; + float* ptr_LSE; + StrideLSE const stride_LSE; + float* ptr_LSE_partial; + StrideLSE const stride_LSE_partial; + int32_t const nheads_kv; + int const* cu_seqlens = nullptr; + int const* seqused = nullptr; + }; + + // Device side kernel params + struct Params { + Element* ptr_O; + ShapeO const shape_O; + StrideO const stride_O; + ShapeOPacked const shape_O_packed; + StrideOPacked const stride_O_packed; + ElementPartial* ptr_O_partial; + StrideO const stride_O_partial; + StrideOPacked const stride_O_partial_packed; + float* ptr_LSE; + StrideLSE const stride_LSE; + ShapeLSEPacked const shape_LSE_packed; + StrideLSEPacked const stride_LSE_packed; + float* ptr_LSE_partial; + StrideLSE const stride_LSE_partial; + StrideLSEPacked const stride_LSE_partial_packed; + cutlass::FastDivmod qhead_per_khead_divmod; + TMA_O tma_store_O; + int const* cu_seqlens = nullptr; + int const* seqused = nullptr; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O); + TMA_O tma_store_O = [&]{ + if constexpr (Use_TMA_O) { + return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast + } else { + return nullptr; + } + }(); + // If PackGQA, reshape O to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size, num_splits) + int const qhead_per_khead = !PackGQA ? 
1 : cute::ceil_div(get<2>(args.shape_O), args.nheads_kv); + auto const shape_O_packed = cute::conditional_return( + args.shape_O, + make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O)) + ); + auto const stride_O_packed = cute::conditional_return( + args.stride_O, + make_stride(make_stride(get<2>(args.stride_O), get<0>(args.stride_O)), get<1>(args.stride_O), get<2>(args.stride_O) * qhead_per_khead, get<3>(args.stride_O), get<4>(args.stride_O)) + ); + auto const stride_O_partial_packed = cute::conditional_return( + args.stride_O_partial, + make_stride(make_stride(get<2>(args.stride_O_partial), get<0>(args.stride_O_partial)), get<1>(args.stride_O_partial), get<2>(args.stride_O_partial) * qhead_per_khead, get<3>(args.stride_O_partial), get<4>(args.stride_O_partial)) + ); + // If PackGQA, Reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits) + auto const shape_LSE_packed = cute::conditional_return( + select<0, 2, 3, 4>(args.shape_O), + make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O)) + ); + auto const stride_LSE_packed = cute::conditional_return( + args.stride_LSE, + make_stride(make_stride(get<1>(args.stride_LSE), get<0>(args.stride_LSE)), get<1>(args.stride_LSE) * qhead_per_khead, get<2>(args.stride_LSE), get<3>(args.stride_LSE)) + ); + auto const stride_LSE_partial_packed = cute::conditional_return( + args.stride_LSE_partial, + make_stride(make_stride(get<1>(args.stride_LSE_partial), get<0>(args.stride_LSE_partial)), get<1>(args.stride_LSE_partial) * qhead_per_khead, get<2>(args.stride_LSE_partial), get<3>(args.stride_LSE_partial)) + ); + return {args.ptr_O, args.shape_O, args.stride_O, shape_O_packed, stride_O_packed, + args.ptr_O_partial, args.stride_O_partial, stride_O_partial_packed, + args.ptr_LSE, args.stride_LSE, shape_LSE_packed, stride_LSE_packed, + args.ptr_LSE_partial, args.stride_LSE_partial, stride_LSE_partial_packed, + cutlass::FastDivmod(qhead_per_khead), + tma_store_O, args.cu_seqlens, args.seqused}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + if constexpr (Use_TMA_O) { + cute::prefetch_tma_descriptor(params.tma_store_O.get_tma_descriptor()); + } + } + + template + CUTLASS_DEVICE void + store(Params const& params, + FrgTensorO& tOrO, + FrgTensorLSE const& lse, + SharedStorage& shared_storage, + TiledMma tiled_mma, + int thread_idx, + cute::tuple const& block_coord + ) { + + auto [m_block, bidh, bidb, split_idx] = block_coord; + int num_splits = get<4>(params.shape_O_packed); + if constexpr (Split && Varlen) { + uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits + int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u); + num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; + split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx + } + bool const is_split = !Split ? false : (!Varlen ? 
true : num_splits > 1); + + Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{}); + // Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO); + + static constexpr bool NeedFP8Permute = FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4); + // If we will possibly need tOrO in FP32, we'd want to permute tOrO before type conversion. + // Otherwise we can permute after conversion. + if constexpr (NeedFP8Permute && Split) { flash::permute_output_fp8_Vcolmajor(tOrO); } + Tensor tOrO_out = make_tensor_like(tOrO); + flash::convert_type_out(tOrO, tOrO_out); + if constexpr (NeedFP8Permute && !Split) { flash::permute_output_fp8_Vcolmajor(tOrO_out); } + + // Make sure all WGs have finished reading V + // Technically we don't need this if we're not using smem, but the mainloop makes the assumption that + // all epilogue threads sync at least once during the epilogue (so that we can start loading Q with + // cp.async if we need). + flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + + // Step 1: Write O from rmem -> smem + if constexpr (Use_smem) { + auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); + Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + // Tensor taccOsO = smem_thr_copy_O.partition_D(sO_pi); // ((Atom,AtomNum),PIPE_M,PIPE_N) + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + if constexpr (Use_TMA_O) { + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } else { + flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + } else { + if constexpr (ArchTag::kMinComputeCapability >= 90) { + #pragma unroll + for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { + shared_storage.pipelines.barrier_O.arrive(cta_id); + } + } + } + + flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused}; + bool is_varlen = Varlen && params.cu_seqlens; + int offset_o = seqlen_info.offset; + int seqlen_o = seqlen_info.seqlen; + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); + + // Step 2: Write LSE from rmem -> gmem + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + // (MMA,MMA_M,MMA_K) + Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + static_assert(decltype(size<0, 0>(taccOcO))::value == 2); + static_assert(decltype(size<0, 1>(taccOcO))::value == 2); + Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); + Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + + using PackGQA_t = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; + using PackGQApartial_t = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, ElementPartial>; + + Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? 
params.stride_LSE : params.stride_LSE_partial)), + params.shape_LSE_packed, + !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx); + // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); } + if (!LargeHeadDimV || warp_group_idx == 0) { + if constexpr (!PackGQA) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + int const row = m_block * kBlockM + get<0>(taccOcO_row(mi)); + if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); } + } + } else { + PackGQA_t::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + } + } + + // Step 3: Write O from smem -> gmem + if constexpr (Use_TMA_O) { + Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx); + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) + auto block_tma_O = params.tma_store_O.get_slice(_0{}); + Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K) + Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K) + int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0); + if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) { + cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + if (cute::elect_one_sync()) { + cute::copy(params.tma_store_O, tOsO, tOgO); + tma_store_arrive(); + tma_store_wait<0>(); + #pragma unroll + for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { + shared_storage.pipelines.barrier_O.arrive(cta_id); + } + } + } + } else { // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence + if (!is_split) { + Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? 
bidb : 0, _0{}); + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) + // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast(&mO(0)) - reinterpret_cast(params.ptr_O)); } + GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + // Tensor tOsO = gmem_thr_copy_O.partition_S(sO_pi); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOrO = make_fragment_like(tOsO); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + if constexpr (ArchTag::kMinComputeCapability >= 90) { + cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_v + #pragma unroll + for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { + shared_storage.pipelines.barrier_O.arrive(cta_id); + } + } + if constexpr (!PackGQA) { + // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + Tensor tOpO = make_tensor(make_shape(size<2>(tOsO))); + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM + ); + } else { + // If PackGQA, we split the work of compute O_ptr among threads in the same row + PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + } + } else { + Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset_o * get<0>(params.stride_O_partial)), params.shape_O_packed, params.stride_O_partial_packed)(_, _, bidh, !is_varlen ? 
bidb : 0, split_idx); + Tensor gOpartial = local_tile(mOpartial, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) + // We already arrived on barrier_O earlier if !Use_smem + if constexpr (Use_smem) { + if constexpr (ArchTag::kMinComputeCapability >= 90) { + #pragma unroll + for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { + shared_storage.pipelines.barrier_O.arrive(cta_id); + } + } + } + if constexpr (!PackGQA) { + static constexpr int kGmemElemsPerStoreDirect = 2; + cute::Copy_Atom, ElementPartial> gmem_copy_direct; + // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout())); + Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int>{}); + Tensor tOgO = thread_mma.partition_C(gOpartial); + Tensor tOgO_rowcol = make_tensor(tOgO.data(), flash::convert_layout_acc_rowcol(tOgO.layout())); + Tensor tOgO_copy = cute::tiled_divide(tOgO_rowcol, Shape<_1, Int>{}); + Tensor taccOcO_col = taccOcO_rowcol(_0{}, _); + #pragma unroll + for (int m = 0; m < size(taccOcO_row); ++m) { + if (get<0>(taccOcO_row(m)) < seqlen_o - m_block * kBlockM) { + #pragma unroll + for (int k = 0; k < size(taccOcO_col) / kGmemElemsPerStoreDirect; ++k) { + if (get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)) < get<1>(params.shape_O)) { + cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), tOgO_copy(_, m, k)); + } + } + } + } + } else { + PackGQApartial_t::store_O_direct(mOpartial, tOrO, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + } + } + } + } + + CUTLASS_DEVICE void + store_tail() { + // Don't need to do tma_store_wait<0>() here since we already did in @store + } + + // Write 0 to output and -inf to LSE + CUTLASS_DEVICE void + store_zero( + Params const& params, + int thread_idx, + cute::tuple const& block_coord + ) { + static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); + auto [m_block, bidh, bidb, split_idx] = block_coord; + int num_splits = get<4>(params.shape_O_packed); + if constexpr (Split && Varlen) { + uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits + int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u); + num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; + split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx + } + bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1); + + flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused}; + bool const is_varlen = Varlen && params.cu_seqlens; + int offset_o = seqlen_info.offset; + int seqlen_o = seqlen_info.seqlen; + int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor; + Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)), + params.shape_LSE_packed, + !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 
0 : split_idx); + Tensor gLSE = local_tile(mLSE, Shape>{}, make_coord(m_block)); + + static_assert(kBlockM <= NumEpilogueThreads); + if (thread_idx < kBlockM) { + const int row = m_block * kBlockM + thread_idx; + if constexpr (!PackGQA) { + if (row < seqlen_o) { mLSE(row) = -INFINITY; } + } else { + if (row < seqlen_o * qhead_per_khead) { + int m_idx, h_idx; + m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row); + // mLSE has shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord" + mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY; + } + } + } + + // If split, we don't have to write 0 to mOpartial if the mha_combine kernel is used, + // since it will not use the value of O if LSE is -inf. + if (!is_split) { + Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{}); + + GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + if constexpr (!PackGQA) { + Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_fragment_like(tOgO); + cute::clear(tOrO); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM + ); + } else { + // If PackGQA, we split the work of compute O_ptr among threads in the same row + using PackGQA_t = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; + Tensor tOrO = make_tensor(make_shape(Shape<_1, Int>{}, size<1>(tOcO), size<2>(tOcO))); + cute::clear(tOrO); + PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + } + } + + } + +}; + +} // namespace flash diff --git a/flash-attn/flash.h b/flash-attn/flash.h new file mode 100644 index 0000000000000000000000000000000000000000..bee89e5f054992919136f97e0b9c118d6790c773 --- /dev/null +++ b/flash-attn/flash.h @@ -0,0 +1,218 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + using index_t = int64_t; + // The QKV matrices. + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + index_t v_dim_stride; + + // The number of heads. + int h, h_k; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + using index_t = int64_t; + + // The O matrix (output). 
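+  // o_ptr is the final output; oaccum_ptr holds the per-split partial outputs (accumulated in FP32)
+  // that run_mha_fwd_combine reduces when num_splits > 1.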
+ void * __restrict__ o_ptr; + void * __restrict__ oaccum_ptr; + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The pointer to the softmax sum. + void * __restrict__ softmax_lse_ptr; + void * __restrict__ softmax_lseaccum_ptr; + + // For FP8 scaling + float * __restrict__ q_descale_ptr; + float * __restrict__ k_descale_ptr; + float * __restrict__ v_descale_ptr; + index_t q_descale_batch_stride; + index_t q_descale_head_stride; + index_t k_descale_batch_stride; + index_t k_descale_head_stride; + index_t v_descale_batch_stride; + index_t v_descale_head_stride; + + // The dimensions. + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; + int total_q, total_k, total_knew; + int b_k; // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q + int dv, dv_rounded; // For the case where V headdim is different from Q/K headdim + + // The scaling factors for the kernel. + float scale_softmax; + float softcap; + + // array of length b+1 holding starting offset of each sequence. + int * __restrict__ cu_seqlens_q; + int * __restrict__ cu_seqlens_k; + int * __restrict__ cu_seqlens_knew; + int * __restrict__ leftpad_k; + + // If provided, the actual length of each q/k sequence. + int *__restrict__ seqused_q; + int *__restrict__ seqused_k; + + // The stride between rows of Oaccum. + index_t oaccum_split_stride; + index_t oaccum_batch_stride; + index_t oaccum_row_stride; + index_t oaccum_head_stride; + + // The stride between rows of LSEaccum. + index_t lseaccum_split_stride; + index_t lseaccum_batch_stride; + index_t lseaccum_head_stride; + + // The K_new and V_new matrices. + void * __restrict__ knew_ptr; + void * __restrict__ vnew_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride; + index_t vnew_batch_stride; + index_t knew_row_stride; + index_t vnew_row_stride; + index_t knew_head_stride; + index_t vnew_head_stride; + + void *__restrict__ qv_ptr; + index_t qv_batch_stride; + index_t qv_row_stride; + index_t qv_head_stride; + + // The cos and sin matrices for rotary embedding. + void * __restrict__ rotary_cos_ptr; + void * __restrict__ rotary_sin_ptr; + int *__restrict__ seqlens_rotary; + + // The indices to index into the KV cache. + int * __restrict__ kv_batch_idx; + + // Paged KV cache + int * __restrict__ page_table; + index_t page_table_batch_stride; + int page_size; + int num_pages; + bool pagedkv_tma; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + // uint32_t p_dropout_in_uint; + // uint16_t p_dropout_in_uint16_t; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + + // Local window size + int window_size_left, window_size_right; + int attention_chunk; + + // Pointer to the RNG seed (idx 0) and offset (idx 1). 
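+  // Only relevant when dropout is enabled (p_dropout > 0).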
+ uint64_t * rng_state; + + bool is_bf16; + bool is_fp32; + bool is_e4m3; + bool is_causal; + bool is_local; + + bool is_rotary_interleaved; + + int num_splits; // For split-KV version + bool pack_gqa; + + int * __restrict__ tile_count_semaphore; + // int * __restrict__ num_m_blocks_ptr; + // int * __restrict__ num_n_blocks_ptr; + int * __restrict__ num_splits_dynamic_ptr; + bool skip_scheduler_metadata_computation; + + int arch; + int num_sm; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_bwd_params : public Flash_fwd_params { + using index_t = int64_t; + + // The dO and dQKV matrices. + void *__restrict__ do_ptr; + void *__restrict__ dq_ptr; + void *__restrict__ dk_ptr; + void *__restrict__ dv_ptr; + + // To accumulate dQ + void *__restrict__ dq_accum_ptr; + void *__restrict__ dk_accum_ptr; + void *__restrict__ dv_accum_ptr; + + // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q + // dimension void *__restrict__ dk_accum_ptr; void *__restrict__ + // dv_accum_ptr; + + // The stride between rows of the dO, dQ, dK and dV matrices. + index_t do_batch_stride; + index_t do_row_stride; + index_t do_head_stride; + index_t dq_batch_stride; + index_t dk_batch_stride; + index_t dv_batch_stride; + index_t dq_row_stride; + index_t dk_row_stride; + index_t dv_row_stride; + index_t dq_head_stride; + index_t dk_head_stride; + index_t dv_head_stride; + + // The pointer to the softmax d sum. + void *__restrict__ dsoftmax_sum; + void *__restrict__ softmax_lse_log2_ptr; + + int *__restrict__ dq_semaphore; + int *__restrict__ dk_semaphore; + int *__restrict__ dv_semaphore; + + bool deterministic; + index_t dq_accum_split_stride; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); +void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl); +template +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); +template +void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); diff --git a/flash-attn/flash_api.cpp b/flash-attn/flash_api.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c4790f4e2df5a83359183f41e2e45e7ec11f1e4b --- /dev/null +++ b/flash-attn/flash_api.cpp @@ -0,0 +1,1720 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#include +#include +#include +#include + +#include + +#include "flash.h" +#include "static_switch.h" +#include "tile_size.h" +#include "heuristics.h" +#include "cuda_check.h" + + +extern "C" { +/* Creates a dummy empty _C module that can be imported from Python. + The import from Python will load the .so consisting of this file + in this extension, so that the TORCH_LIBRARY static initializers + below are run. */ +PyObject* PyInit__C(void) +{ + static struct PyModuleDef module_def = { + PyModuleDef_HEAD_INIT, + "_C", /* name of module */ + NULL, /* module documentation, may be NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. 
*/ + NULL, /* methods */ + }; + return PyModule_Create(&module_def); +} +} + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +void set_params_fprop(Flash_fwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + at::Tensor out, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *seqused_q, + void *seqused_k, + void *softmax_lse_d, + float p_dropout, + float softmax_scale, + int window_size_left, + int window_size_right, + int attention_chunk, + const float softcap=0.f, + const int sm_margin=0) { + + // Reset the parameters + params = {}; + + params.is_bf16 = q.dtype() == torch::kBFloat16; + params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn; + + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + // All stride are in elements, not bytes. + params.q_row_stride = q.stride(-3); + params.k_row_stride = k.stride(-3); + params.v_row_stride = v.stride(-3); + params.q_head_stride = q.stride(-2); + params.k_head_stride = k.stride(-2); + params.v_head_stride = v.stride(-2); + params.v_dim_stride = v.stride(-1); + params.o_ptr = out.data_ptr(); + params.o_row_stride = out.stride(-3); + params.o_head_stride = out.stride(-2); + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = q.stride(0); + params.o_batch_stride = out.stride(0); + } + if (cu_seqlens_k_d == nullptr) { + params.k_batch_stride = k.stride(0); + params.v_batch_stride = v.stride(0); + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + params.seqused_q = static_cast(seqused_q); + params.seqused_k = static_cast(seqused_k); + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = b; + params.h = h; + params.h_k = h_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = d; + params.d_rounded = d_rounded; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.softcap = softcap; + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f - p_dropout; + // Convert p from float to int so we don't have to convert the random uint to float to compare. + // [Minor] We want to round down since when we do the comparison we use <= instead of < + // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); + // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); + params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); + params.rp_dropout = 1.f / params.p_dropout; + TORCH_CHECK(p_dropout < 1.f); + #ifdef FLASHATTENTION_DISABLE_DROPOUT + TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); + #endif + + // Causal is the special case where window_size_right == 0 and window_size_left < 0. 
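+  // (i.e. each query attends to every key at or before its own position, with no left-window limit).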
+ // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. + params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0; + params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal; + + // TODO: check this + if (window_size_left < 0) { window_size_left = seqlen_k - 1; } + if (window_size_right < 0) { window_size_right = seqlen_q - 1; } + if (attention_chunk > 0) { + window_size_left = std::min(window_size_left, attention_chunk - 1); + window_size_right = std::min(window_size_right, attention_chunk - 1); + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.attention_chunk = attention_chunk; + + params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; + params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin; + + #ifdef FLASHATTENTION_DISABLE_LOCAL + TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); + #endif +} + +void set_params_dgrad(Flash_bwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + const at::Tensor out, + const at::Tensor dout, + at::Tensor dq, + at::Tensor dk, + at::Tensor dv, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *seqused_q, + void *seqused_k, + void *dq_accum_d, + void *dk_accum_d, + void *dv_accum_d, + void *softmax_lse_d, + void *dsoftmax_sum_d, + float p_dropout, + float softmax_scale, + int window_size_left, + int window_size_right, + int attention_chunk, + const float softcap=0.f, + bool deterministic=false, + int const sm_margin=0) { + + set_params_fprop(params, + b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, + q, k, v, out, + cu_seqlens_q_d, + cu_seqlens_k_d, + seqused_q, + seqused_k, + softmax_lse_d, + p_dropout, + softmax_scale, + window_size_left, + window_size_right, + attention_chunk, + softcap, + sm_margin); + + // Set the pointers and strides. 
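+  // The gradient tensors follow the same convention as the forward pass: strides are in elements,
+  // and batch strides are only set below when cu_seqlens is not provided (non-varlen path).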
+ params.do_ptr = dout.data_ptr(); + params.do_row_stride = dout.stride(-3); + params.do_head_stride = dout.stride(-2); + params.dq_ptr = dq.data_ptr(); + params.dk_ptr = dk.data_ptr(); + params.dv_ptr = dv.data_ptr(); + params.dq_row_stride = dq.stride(-3); + params.dk_row_stride = dk.stride(-3); + params.dv_row_stride = dv.stride(-3); + params.dq_head_stride = dq.stride(-2); + params.dk_head_stride = dk.stride(-2); + params.dv_head_stride = dv.stride(-2); + + if (cu_seqlens_q_d == nullptr) { + params.do_batch_stride = dout.stride(0); + params.dq_batch_stride = dq.stride(0); + params.dk_batch_stride = dk.stride(0); + params.dv_batch_stride = dv.stride(0); + } + + params.dq_accum_ptr = dq_accum_d; + params.dk_accum_ptr = dk_accum_d; + params.dv_accum_ptr = dv_accum_d; + + // Softmax sum + params.dsoftmax_sum = dsoftmax_sum_d; + + params.deterministic = deterministic; +} + +template +void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { + if (!params.is_e4m3) { + if (params.is_bf16) { + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d <= 64) { + if constexpr (Arch == 90) { + if (params.dv > 256) { + return run_mha_fwd_(params, stream); + } else if (params.dv > 64) { + return run_mha_fwd_(params, stream); + } + } + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d <= 192) { + if constexpr (Arch == 90) { + if (params.dv <= 128) { + return run_mha_fwd_(params, stream); + } + } + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_fwd_(params, stream); } + #endif + } else { + #ifndef FLASHATTENTION_DISABLE_FP16 + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d <= 64) { + if constexpr (Arch == 90) { + if (params.dv > 256) { + return run_mha_fwd_(params, stream); + } else if (params.dv > 64) { + return run_mha_fwd_(params, stream); + } + } + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d <= 192) { + if constexpr (Arch == 90) { + if (params.dv <= 128) { + return run_mha_fwd_(params, stream); + } + } + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_fwd_(params, stream); } + #endif + #else + TORCH_CHECK(false, "This flash attention build does not support FP16."); + #endif + } + } else { + #ifndef FLASHATTENTION_DISABLE_FP8 + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d <= 192) { + 
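+      // Same hdim-192 special case as the BF16/FP16 branches above: on SM90, dv <= 128
+      // dispatches the 192x128 instantiation, otherwise 192x192 is used.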
if constexpr (Arch == 90) { + if (params.dv <= 128) { + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); + } + } + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #else + TORCH_CHECK(false, "This flash attention build does not support FP8."); + #endif + } +} + +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + // HEADDIM_SWITCH(params.d, [&] { + // run_mha_fwd_(params, stream); + // }); + TORCH_CHECK(params.num_splits >= 1); + ARCH_SWITCH(params.arch, Arch, [&] { + SPLIT_SWITCH(params.num_splits > 1, Split, [&] { + PAGEDKV_SWITCH(params.page_table && !params.pagedkv_tma, PagedKVNonTMA, [&] { + PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] { + // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation + static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKVNonTMA || Split; + SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] { + run_mha_fwd_constexpr(params, stream); + }); + }); + }); + }); + }); +} + +void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl=false) { + #ifndef FLASHATTENTION_DISABLE_SPLIT + // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively + // so that kBlockM is smaller and we have more parallelism. + if (params.is_fp32) { + if (params.dv <= 64) { + run_mha_fwd_combine_(params, stream, enable_pdl); + } else { + run_mha_fwd_combine_(params, stream, enable_pdl); + } + } else if (params.is_bf16) { + if (params.dv <= 64) { + run_mha_fwd_combine_(params, stream, enable_pdl); + } else { + run_mha_fwd_combine_(params, stream, enable_pdl); + } + } else { + if (params.dv <= 64) { + run_mha_fwd_combine_(params, stream, enable_pdl); + } else { + run_mha_fwd_combine_(params, stream, enable_pdl); + } + } + #else + TORCH_CHECK(false, "This flash attention build does not support combine kernels."); + #endif +} + +inline bool get_pagedkv_tma(Flash_fwd_params const& params) { + if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; } + // This needs to match the kernel configs + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f); + int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); + int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90); + // Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower, + // at least for MLA. + return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM; +} + +inline bool get_pack_gqa(Flash_fwd_params const& params) { + // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size. + // Has little effect on speed. 
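+  // run_mha_fwd compiles these configurations (SM8x, non-TMA paged KV, split) with PackGQA
+  // forced on, so returning true here keeps this heuristic consistent with that dispatch.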
+ if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; } + #ifdef FLASHATTENTION_DISABLE_PACKGQA + return false; + #else + // params.page_table must already be set + if (params.h == params.h_k) { return false; } + // This needs to match the kernel configs + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); + int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); + return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM); + #endif +} + +inline int get_num_splits(Flash_fwd_params const& params) { + #ifdef FLASHATTENTION_DISABLE_SPLIT + return 1; + #else + // Always enable PackGQA for Split + // params.page_table must already be set + // This needs to match the kernel configs + bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k; + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); + // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits + // has not been set here. It's OK though because we might just underestimate kBlockN a bit + auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr); + int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); + int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); + int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k); + // If is_local, we're not going to load all of seqlen_k + int const seqlen_k_loaded = !params.is_local + ? params.seqlen_k + : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM)); + int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN; + int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM; + int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2); + // Always enable PackGQA for Split + // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits. + // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending + // that batch = 1. + int total_mblocks = (params.num_splits_dynamic_ptr ? 
1 : params.b) * params.h_k * num_m_blocks; + return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128); + #endif +} + +inline int get_max_headdim() { + #ifndef FLASHATTENTION_DISABLE_HDIM256 + return 256; + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + return 192; + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + return 128; + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + return 96; + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM64 + return 64; + #endif + return 0; +} + +inline int round_up_headdim(int head_size) { + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (head_size <= 64) { return 64; } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (head_size <= 96) { return 96; } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (head_size <= 128) { return 128; } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (head_size <= 192) { return 192; } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (head_size <= 256) { return 256; } + #endif + return 256; +} + +inline int round_up_headdimv(int head_size) { + if (head_size <= 64) { return 64; } + if (head_size <= 96) { return 96; } + if (head_size <= 128) { return 128; } + if (head_size <= 192) { return 192; } + if (head_size <= 256) { return 256; } + return 512; +} + +// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available +at::Tensor +mha_fwd_get_scheduler_metadata( + int64_t batch_size, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + int64_t num_heads, + int64_t num_heads_k, + int64_t headdim, + int64_t headdim_v, + at::ScalarType qkv_dtype, + at::Tensor seqused_k, // b + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional cu_seqlens_k_new_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional leftpad_k_, // b + std::optional page_size, + int64_t max_seqlen_k_new, // 0 means we're not appending new KV + bool is_causal, + int64_t window_size_left, + int64_t window_size_right, + int64_t attention_chunk, + bool has_softcap, + int64_t num_splits, + std::optional pack_gqa_, + int64_t sm_margin + ) { + + TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn, + "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + // Reset the parameters + Flash_fwd_params params{}; + params.is_bf16 = qkv_dtype == at::ScalarType::BFloat16; + params.is_e4m3 = qkv_dtype == at::ScalarType::Float8_e4m3fn; + params.b = batch_size; + params.seqlen_q = max_seqlen_q; + params.seqlen_k = max_seqlen_k; + params.h = num_heads; + params.h_k = num_heads_k; + params.d = headdim; + params.dv = headdim_v; + params.d_rounded = round_up_headdim(headdim); + params.dv_rounded = headdim_v == headdim ? params.d_rounded : round_up_headdimv(headdim_v); + params.seqlen_knew = max_seqlen_k_new; + + bool const is_varlen_q = cu_seqlens_q_.has_value(); + params.cu_seqlens_q = is_varlen_q ? cu_seqlens_q_.value().data_ptr() : nullptr; + bool const is_varlen_k = cu_seqlens_k_.has_value(); + params.cu_seqlens_k = is_varlen_k ? cu_seqlens_k_.value().data_ptr() : nullptr; + params.cu_seqlens_knew = cu_seqlens_k_new_.has_value() ? cu_seqlens_k_new_.value().data_ptr() : nullptr; + params.seqused_q = seqused_q_.has_value() ? 
seqused_q_.value().data_ptr() : nullptr; + params.seqused_k = seqused_k.data_ptr(); + params.leftpad_k = leftpad_k_.has_value() ? leftpad_k_.value().data_ptr() : nullptr; + params.knew_ptr = params.seqlen_knew > 0 ? reinterpret_cast(1) : nullptr; + if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; } + if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; } + // causal=true is the same as causal=false in this case + if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) { + // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA + if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) { + is_causal = false; + } + } + if (is_causal) { window_size_right = 0; } + + params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0; + params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal; + if (window_size_left < 0) { window_size_left = max_seqlen_k - 1; } + if (window_size_right < 0) { window_size_right = max_seqlen_q - 1; } + if (attention_chunk > 0) { + window_size_left = std::min(window_size_left, attention_chunk - 1); + window_size_right = std::min(window_size_right, attention_chunk - 1); + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.attention_chunk = attention_chunk; + params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; + params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin; + params.softcap = has_softcap ? 1.0f : 0.0f; + + params.page_size = page_size.has_value() ? page_size.value() : 1; + params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast(1); + + bool const use_dynamic_split = params.b <= 992; + params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); + + params.pagedkv_tma = get_pagedkv_tma(params); + params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide + params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); + + bool is_varlen = true; + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)seqused_k.get_device()}; + + auto opts = seqused_k.options(); + // This needs to be set after get_num_splits + at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic + bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1; + if (scheduler_needs_semaphore || use_dynamic_split) { + tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::kInt32)); + if (scheduler_needs_semaphore) { + if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing + params.tile_count_semaphore = tile_count_semaphore.data_ptr(); + } else { + params.tile_count_semaphore = nullptr; + } + params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 : nullptr; + } + + if (params.num_splits_dynamic_ptr) { + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 
1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); + auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr); + int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); + int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); + auto stream = at::cuda::getCurrentCUDAStream().stream(); + prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/); + CHECK_CUDA_KERNEL_LAUNCH(); + } + return tile_count_semaphore; +} + +// b: batch_size +// b_k: batch_size_k +// s_q: seqlen_q +// s_k: seqlen_k +// s_k_new: seqlen_k_new +// h: num_heads +// h_k: num_heads_k +// d: head_size +std::tuple +mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + at::Tensor k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. + at::Tensor v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. + std::optional k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new + std::optional v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new + std::optional q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q + std::optional out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional cu_seqlens_k_new_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional seqused_k_, // b. If given, only this many elements of each batch element's keys are used. + std::optional max_seqlen_q_, + // TODO: check if we need max_seqlen_k + std::optional max_seqlen_k_, + std::optional page_table_, // (b_k, max_num_pages_per_seq) + std::optional kv_batch_idx_, // b. 
indices to index into the KV cache + std::optional leftpad_k_, // b + std::optional rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional seqlens_rotary_, // b + std::optional q_descale_, // (b, h_k), not (b, h) + std::optional k_descale_, // (b, h_k) + std::optional v_descale_, // (b, h_k) + std::optional softmax_scale_, + bool is_causal, + int64_t window_size_left, + int64_t window_size_right, + int64_t attention_chunk, + double softcap, + bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + std::optional scheduler_metadata_, // (b + 1) + int64_t num_splits, + std::optional pack_gqa_, + int64_t sm_margin + ) { + + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm8x = dprops->major >= 8; + TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + + auto q_type = q.scalar_type(); + TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn, + "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); + if (dprops->major < 9) { + TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "FlashAttention on Ampere/Ada cards only supports fp16 and bf16 data type"); + } + TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); + TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + at::Tensor page_table; + const bool paged_KV = page_table_.has_value(); + if (paged_KV) { + page_table = page_table_.value(); + CHECK_DEVICE(page_table); + TORCH_CHECK(page_table.dtype() == torch::kInt32, "page_table must have dtype torch.int32"); + TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension"); + } + + at::Tensor cu_seqlens_q; + bool const is_varlen_q = cu_seqlens_q_.has_value(); + if (is_varlen_q) { + cu_seqlens_q = cu_seqlens_q_.value(); + CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32"); + TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided"); + } + at::Tensor cu_seqlens_k; + bool const is_varlen_k = cu_seqlens_k_.has_value(); + if (is_varlen_k) { + cu_seqlens_k = cu_seqlens_k_.value(); + CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32"); + TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided"); + TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported"); + TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then page table is not supported"); + } + + auto const sizes = q.sizes(); + const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1; + int seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value(); + int total_q = !is_varlen_q ? 
batch_size * sizes[1] : sizes[0]; + int num_heads = q.size(-2); + int const head_size = q.size(-1); + int const head_size_v = v.size(-1); + int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1); + int const num_pages = !paged_KV ? 0 : k.size(0); + int const page_size = !paged_KV ? 1 : k.size(1); + int const seqlen_k = !is_varlen_k ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value(); + int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0); + int const num_heads_k = k.size(-2); + int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0); + double softmax_scale = 1.0 / sqrt(double(head_size)); + if (softmax_scale_.has_value()) { + softmax_scale = softmax_scale_.value(); + } + if (!kv_batch_idx_.has_value()) { + TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k"); + } + int const max_headdim = get_max_headdim(); + TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + if (head_size_v != head_size) { + TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) || + (head_size <= 64 && head_size_v <= 512), + "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], " + "or (Q/K <= 64 and V <= 512)."); + TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim"); + if (head_size_v > 256) { + TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "HeaddimV > 256 requires fp16 and bf16 data type"); + } + } + + // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM + // TODO: check this + if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } + if (window_size_right >= seqlen_q - 1) { window_size_right = -1; } + // causal=true is the same as causal=false in this case + if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) { + // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA + if ((head_size <= 64 || head_size > 128) || !paged_KV) { + is_causal = false; + } + } + if (is_causal) { window_size_right = 0; } + + if (!is_varlen_q) { + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + } else { + CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + } + if (!paged_KV) { + if (!is_varlen_k) { + CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + } + } else { + CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); + CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); + CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq); + } + + if (seqused_q_.has_value()){ + auto seqused_q = seqused_q_.value(); + TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32"); + CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q); + CHECK_SHAPE(seqused_q, batch_size); + } + if (seqused_k_.has_value()) { + auto seqused_k = seqused_k_.value(); + 
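+        // Illustrative example (editorial): for a batch of 3 decode requests sharing a
+        // pre-allocated KV cache, seqused_k = {5, 17, 3} would mean only the first 5, 17
+        // and 3 cached keys of each sequence are attended to; the checks that follow only
+        // enforce int32 dtype, device placement, contiguity and shape (batch_size,).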
TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32"); + CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k); + CHECK_SHAPE(seqused_k, batch_size); + } + + if (leftpad_k_.has_value()) { + auto leftpad_k = leftpad_k_.value(); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + } + + // This is what we will template on + bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value(); + #ifdef FLASHATTENTION_DISABLE_VARLEN + TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); + #endif + + int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8; + TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); + TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); + + auto opts = q.options(); + auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type; + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + if (!is_varlen_q) { + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); + } else { + CHECK_SHAPE(out, total_q, num_heads, head_size_v); + } + } else { + out = !is_varlen_q + ? torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type)) + : torch::empty({total_q, num_heads, head_size_v}, opts.dtype(out_type)); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + int const head_size_rounded = round_up_headdim(head_size); + int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdimv(head_size_v); + int const seqlen_q_rounded = round_multiple(seqlen_q, 128); + int const seqlen_k_rounded = round_multiple(seqlen_k, 128); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + at::Tensor softmax_lse; + if (!is_varlen_q) { + softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + } else { + softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + } + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(), + !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(), + seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr, + seqused_k_.has_value() ? 
seqused_k_.value().data_ptr() : nullptr, + softmax_lse.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + window_size_left, + window_size_right, + attention_chunk, + softcap, + sm_margin); + params.total_q = total_q; + params.total_k = total_k; + params.b_k = batch_size_k; + params.dv = head_size_v; + params.dv_rounded = head_size_v_rounded; + if (leftpad_k_.has_value()) { // This needs to be set before get_pagedkv_tma + params.leftpad_k = static_cast(leftpad_k_.value().data_ptr()); + } + if (paged_KV) { + params.page_table = page_table.data_ptr(); + params.page_table_batch_stride = page_table.stride(0); + } + params.page_size = page_size; + params.num_pages = num_pages; + + if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma + at::Tensor k_new, v_new; + TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in"); + TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqlens_k must also be passed in"); + TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache"); + at::Tensor cu_seqlens_k_new; + bool const is_varlen_k_new = cu_seqlens_k_new_.has_value(); + if (is_varlen_k_new) { + cu_seqlens_k_new = cu_seqlens_k_new_.value(); + CHECK_DEVICE(cu_seqlens_k_new); CHECK_CONTIGUOUS(cu_seqlens_k_new); + TORCH_CHECK(cu_seqlens_k_new.dtype() == torch::kInt32, "cu_seqlens_k_new must have dtype torch.int32"); + } + k_new = k_new_.value(); + v_new = v_new_.value(); + TORCH_CHECK(k_new.dtype() == q_type, "k_new must have the same dtype as query"); + TORCH_CHECK(v_new.dtype() == q_type, "v_new must have the same dtype as query"); + CHECK_DEVICE(k_new); CHECK_DEVICE(v_new); + TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension"); + TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension"); + // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new + int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0; + int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1): k_new.size(0); + if (!is_varlen_k_new) { + CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v); + CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1); + } + params.seqlen_knew = seqlen_k_new; + params.total_knew = total_k_new; + params.knew_ptr = k_new.data_ptr(); + params.vnew_ptr = v_new.data_ptr(); + // All stride are in elements, not bytes. + params.knew_row_stride = k_new.stride(-3); + params.vnew_row_stride = v_new.stride(-3); + params.knew_head_stride = k_new.stride(-2); + params.vnew_head_stride = v_new.stride(-2); + if (!is_varlen_k_new) { + params.knew_batch_stride = k_new.stride(0); + params.vnew_batch_stride = v_new.stride(0); + } + if (is_varlen_k_new) { + params.cu_seqlens_knew = static_cast(cu_seqlens_k_new.data_ptr()); + } + } + + // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel + bool const use_dynamic_split = is_varlen && params.b <= 992; + // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it + params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); + + params.pagedkv_tma = get_pagedkv_tma(params); + params.num_splits = num_splits <= 0 ? 
get_num_splits(params) : num_splits; + // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide + params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); + + // This needs to be set after get_num_splits + at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic + // We don't use the persistent scheduler if Split and not Varlen + bool const scheduler_needs_semaphore = params.arch >= 90 + ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) + : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); + if (scheduler_needs_semaphore || use_dynamic_split) { + int metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b; + params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value(); + if (scheduler_metadata_.has_value()) { + at::Tensor scheduler_metadata = scheduler_metadata_.value(); + CHECK_DEVICE(scheduler_metadata); + CHECK_SHAPE(scheduler_metadata, metadata_size); + CHECK_CONTIGUOUS(scheduler_metadata); + TORCH_CHECK(scheduler_metadata.dtype() == torch::kInt32, "scheduler_metadata must have dtype int32"); + tile_count_semaphore = scheduler_metadata; + } else { + tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32)); + } + if (scheduler_needs_semaphore && !use_dynamic_split) { + tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing + } + params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr() : nullptr; + params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 : nullptr; + } + + if (q_v_.has_value()) { + TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); + TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "q_v is only supported for fp16 and bf16 data type"); + TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); + at::Tensor q_v = q_v_.value(); + TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query"); + CHECK_DEVICE(q_v); + TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); + if (!is_varlen_q) { + CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v); + } else { + CHECK_SHAPE(q_v, total_q, num_heads, head_size_v); + } + params.qv_ptr = q_v.data_ptr(); + // All stride are in elements, not bytes. 
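+        // E.g. (illustrative numbers) for a contiguous (b, s_q, h, dv) tensor with h = 16
+        // and dv = 64: stride(-3) = h * dv = 1024 elements between consecutive tokens and
+        // stride(-2) = dv = 64 elements between consecutive heads.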
+ params.qv_row_stride = q_v.stride(-3); + params.qv_head_stride = q_v.stride(-2); + if (!is_varlen_q) { + params.qv_batch_stride = q_v.stride(0); + } + } + + if (rotary_cos_.has_value()) { + TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided"); + auto rotary_cos = rotary_cos_.value(); + CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos); + params.rotary_dim = rotary_cos.size(1) * 2; + TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); + TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); + const int seqlen_ro = rotary_cos.size(0); + if (paged_KV) { + TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); + } + CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); + TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + + TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); + auto rotary_sin = rotary_sin_.value(); + CHECK_DEVICE(rotary_sin); CHECK_CONTIGUOUS(rotary_sin); + CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); + TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + params.rotary_cos_ptr = rotary_cos.data_ptr(); + params.rotary_sin_ptr = rotary_sin.data_ptr(); + params.is_rotary_interleaved = is_rotary_interleaved; + if (seqlens_rotary_.has_value()) { + at::Tensor seqlens_rotary = seqlens_rotary_.value(); + CHECK_DEVICE(seqlens_rotary); CHECK_CONTIGUOUS(seqlens_rotary); + TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32"); + CHECK_SHAPE(seqlens_rotary, batch_size); + params.seqlens_rotary = seqlens_rotary.data_ptr(); + } + } else { + params.rotary_dim = 0; + } + + if (kv_batch_idx_.has_value()) { + auto kv_batch_idx = kv_batch_idx_.value(); + CHECK_DEVICE(kv_batch_idx); CHECK_CONTIGUOUS(kv_batch_idx); + TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32"); + params.kv_batch_idx = reinterpret_cast(kv_batch_idx.data_ptr()); + } + + at::Tensor out_accum, softmax_lse_accum; + auto outaccum_type = at::ScalarType::Float; + if (params.num_splits > 1) { + TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported"); + if (!is_varlen_q) { + out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(outaccum_type)); + softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + params.oaccum_batch_stride = out_accum.stride(1); + params.lseaccum_batch_stride = softmax_lse_accum.stride(1); + } else { + out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size_v}, opts.dtype(outaccum_type)); + softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat)); + } + params.is_fp32 = false; + params.oaccum_ptr = out_accum.data_ptr(); + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_split_stride = out_accum.stride(0); + params.oaccum_row_stride = out_accum.stride(-2); + params.oaccum_head_stride = out_accum.stride(-3); + params.lseaccum_split_stride = softmax_lse_accum.stride(0); + params.lseaccum_head_stride = softmax_lse_accum.stride(-2); + } + + if (q_type == at::ScalarType::Float8_e4m3fn) { + if (q_descale_.has_value()) { + auto q_descale = q_descale_.value(); + 
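+            // Editorial note: the descale tensors are shaped (b, h_k) rather than (b, h),
+            // so every query head mapped onto the same KV head shares one dequantization
+            // scale; e.g. with h = 32 and h_k = 8 (hypothetical), each group of 4 query
+            // heads reads a single q_descale entry.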
CHECK_DEVICE(q_descale); + CHECK_SHAPE(q_descale, batch_size, num_heads_k); + params.q_descale_ptr = q_descale.data_ptr(); + params.q_descale_batch_stride = q_descale.stride(0); + params.q_descale_head_stride = q_descale.stride(1); + } else { + params.q_descale_ptr = nullptr; + } + if (k_descale_.has_value()) { + auto k_descale = k_descale_.value(); + CHECK_DEVICE(k_descale); + CHECK_SHAPE(k_descale, batch_size, num_heads_k); + params.k_descale_ptr = k_descale.data_ptr(); + params.k_descale_batch_stride = k_descale.stride(0); + params.k_descale_head_stride = k_descale.stride(1); + } else { + params.k_descale_ptr = nullptr; + } + if (v_descale_.has_value()) { + auto v_descale = v_descale_.value(); + CHECK_DEVICE(v_descale); + CHECK_SHAPE(v_descale, batch_size, num_heads_k); + params.v_descale_ptr = v_descale.data_ptr(); + params.v_descale_batch_stride = v_descale.stride(0); + params.v_descale_head_stride = v_descale.stride(1); + } else { + params.v_descale_ptr = nullptr; + } + } + + #ifdef FLASHATTENTION_DISABLE_LOCAL + TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); + #endif + #ifdef FLASHATTENTION_DISABLE_SOFTCAP + TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping."); + #endif + #ifdef FLASHATTENTION_DISABLE_SPLIT + TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits."); + #endif + #ifdef FLASHATTENTION_DISABLE_PACKGQA + TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, "This flash attention build does not support pack_gqa."); + #endif + #ifdef FLASHATTENTION_DISABLE_PAGEDKV + TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), "This flash attention build does not support paged KV."); + #endif + #ifdef FLASHATTENTION_DISABLE_APPENDKV + TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV."); + #endif + + if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream); + if (params.num_splits > 1) { + if (out_type == at::ScalarType::BFloat16) { + // Since we want output in BF16. Otherwise fwd_combine will output to FP16 + params.is_bf16 = true; + } + // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1 + // and seqlen = total_q, and don't need to dispatch to Varlen there. + // However, with dynamic split, each row needs to know which batch it belongs to + // to read the number of splits, so we just use the varlen version of combine kernel. + // if (is_varlen_q && !seqused_q_.has_value()) { + // if (is_varlen_q) { + // params.b = 1; + // params.seqlen_q = total_q; + // } + // This will zero out the semaphore if needed + run_mha_fwd_combine(params, stream, true /*enable_pdl*/); + } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) { + // need to zero out the semaphore in this case + tile_count_semaphore.index({torch::indexing::Slice(0, 1)}).zero_(); + } + } else if (total_q > 0 && num_heads_k > 0) { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. 
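+        // Editorial example: this branch is reached only when there are no keys at all,
+        // since total_k + total_knew == 0 is required to skip the kernel launch above; a
+        // hypothetical call that appends new KV to an initially empty cache still has
+        // total_knew > 0 and therefore runs the attention kernel instead.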
+        out.zero_();
+        softmax_lse.fill_(std::numeric_limits<float>::infinity());
+    }
+
+    // return {out, softmax_lse};
+    return {out, softmax_lse, out_accum, softmax_lse_accum};
+}
+
+#ifdef FLASHATTENTION_DISABLE_BACKWARD
+void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
+    TORCH_CHECK(false, "Flash-Attention was built with backward disabled");
+}
+#else
+template <int Arch, bool Has_softcap>
+void run_mha_bwd_constexpr(Flash_bwd_params &params, cudaStream_t stream) {
+    if (!params.is_bf16) {
+        #ifndef FLASHATTENTION_DISABLE_FP16
+        #ifndef FLASHATTENTION_DISABLE_HDIM64
+        if (params.d_rounded == 64) { return run_mha_bwd_<Arch, cutlass::half_t, 64, Has_softcap>(params, stream); }
+        #endif
+        #ifndef FLASHATTENTION_DISABLE_HDIM96
+        if (params.d_rounded == 96) { return run_mha_bwd_<Arch, cutlass::half_t, 96, Has_softcap>(params, stream); }
+        #endif
+        #ifndef FLASHATTENTION_DISABLE_HDIM128
+        if (params.d_rounded == 128) { return run_mha_bwd_<Arch, cutlass::half_t, 128, Has_softcap>(params, stream); }
+        #endif
+        #ifndef FLASHATTENTION_DISABLE_HDIM192
+        if (params.d_rounded == 192) { return run_mha_bwd_<Arch, cutlass::half_t, 192, Has_softcap>(params, stream); }
+        #endif
+        #ifndef FLASHATTENTION_DISABLE_HDIM256
+        if (params.d_rounded == 256) { return run_mha_bwd_<Arch, cutlass::half_t, 256, Has_softcap>(params, stream); }
+        #endif
+        #else
+        TORCH_CHECK(false, "This flash attention build does not support FP16.");
+        #endif
+    } else {
+        #ifndef FLASHATTENTION_DISABLE_HDIM64
+        if (params.d_rounded == 64) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 64, Has_softcap>(params, stream); }
+        #endif
+        #ifndef FLASHATTENTION_DISABLE_HDIM96
+        if (params.d_rounded == 96) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 96, Has_softcap>(params, stream); }
+        #endif
+        #ifndef FLASHATTENTION_DISABLE_HDIM128
+        if (params.d_rounded == 128) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 128, Has_softcap>(params, stream); }
+        #endif
+        #ifndef FLASHATTENTION_DISABLE_HDIM192
+        if (params.d_rounded == 192) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 192, Has_softcap>(params, stream); }
+        #endif
+        #ifndef FLASHATTENTION_DISABLE_HDIM256
+        if (params.d_rounded == 256) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 256, Has_softcap>(params, stream); }
+        #endif
+    }
+}
+
+void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
+    // FP16_SWITCH(!params.is_bf16, [&] {
+    //     HEADDIM_SWITCH(params.d, [&] {
+    //         run_mha_bwd_(params, stream);
+    //     });
+    // });
+    ARCH_SWITCH(params.arch, Arch, [&] {
+        SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
+            run_mha_bwd_constexpr<Arch, Has_softcap>(params, stream);
+        });
+    });
+}
+#endif
+
+
+// b: batch_size
+// s_q: seqlen_q
+// s_k: seqlen_k
+// h: num_heads
+// h_k: num_heads_k
+// d: head_size
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
+    at::Tensor dout,  // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
+    at::Tensor q,     // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
+    at::Tensor k,     // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
+    at::Tensor v,     // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k
+    at::Tensor out,   // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
+    at::Tensor softmax_lse,    // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q
+    std::optional<at::Tensor> dq_,   // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
+    std::optional<at::Tensor> dk_,   // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
+    std::optional<at::Tensor> dv_,   // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k
+    std::optional<const at::Tensor> cu_seqlens_q_,   // b+1
+    std::optional<const at::Tensor> cu_seqlens_k_,   // b+1
+    std::optional<const at::Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
+    std::optional<const at::Tensor> seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
+ std::optional max_seqlen_q_, + std::optional max_seqlen_k_, + std::optional softmax_scale_, + bool is_causal, + int64_t window_size_left, + int64_t window_size_right, + double softcap, + bool deterministic, + int64_t sm_margin +) { + + #ifdef FLASHATTENTION_DISABLE_BACKWARD + TORCH_CHECK(false, "This flash attention build does not support backward."); + #endif + + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm8x = dprops->major >= 8; + TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + + auto q_type = q.dtype(); + TORCH_CHECK(q_type == torch::kFloat16 || q_type == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_type, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_type, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_type, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_type, "query and dout must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + + at::Tensor cu_seqlens_q; + bool const is_varlen_q = cu_seqlens_q_.has_value(); + if (is_varlen_q) { + cu_seqlens_q = cu_seqlens_q_.value(); + CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32"); + TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided"); + } + at::Tensor cu_seqlens_k; + bool const is_varlen_k = cu_seqlens_k_.has_value(); + if (is_varlen_k) { + cu_seqlens_k = cu_seqlens_k_.value(); + CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32"); + TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided"); + } + // This is what we will template on + bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value(); + #ifdef FLASHATTENTION_DISABLE_VARLEN + TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); + #endif + + auto const sizes = q.sizes(); + int const batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1; + int const seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value(); + int const total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0]; + int const num_heads = q.size(-2); + int const head_size = q.size(-1); + int const head_size_v = v.size(-1); + int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value(); + int const total_k = !is_varlen_k ? 
batch_size * k.size(1) : k.size(0); + int const num_heads_k = k.size(-2); + TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size_v % 8 == 0, "head_size_v should be a multiple of 8"); + int const max_headdim = get_max_headdim(); + TORCH_CHECK(std::max(head_size, head_size_v) <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + double softmax_scale = 1.0 / sqrt(double(head_size)); + if (softmax_scale_.has_value()) { + softmax_scale = softmax_scale_.value(); + } + + // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM + if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } + if (window_size_right >= seqlen_q - 1) { window_size_right = -1; } + if (is_causal) { window_size_right = 0; } + // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_bprop will set params.is_causal=true. + // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA). + is_causal = window_size_left < 0 && window_size_right == 0; + + int const arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; + int const head_size_rounded = round_up_headdim(std::max(head_size, head_size_v)); + int const head_size_v_rounded = head_size_rounded; + // Very important that these match the kernel configs + bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal; + int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128) + : (head_size_rounded <= 96 ? 64 + : (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80) + : 64)); + int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64; + int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32; + int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80); + int const kBlockN_sm90 = head_size_rounded <= 128 + ? 128 + : (head_size_rounded <= 192 ? 96 : 80); + int const kBlockN_sm80 = head_size_rounded <= 128 + ? 128 + : (head_size_rounded <= 192 ? 80 : 64); + int const kBlockN_sm86 = head_size_rounded <= 64 ? 128 + : (head_size_rounded <= 96 ? 128 + : (head_size_rounded <= 128 ? 96 + : (head_size_rounded <= 192 ? 64 : 64))); + int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 89 ? 
kBlockN_sm86 : kBlockN_sm80); + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM); + int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN); + int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM); + int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN); + + if (!is_varlen_q) { + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); + CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_v); + } else { + CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(out, total_q, num_heads, head_size_v); + CHECK_SHAPE(dout, total_q, num_heads, head_size_v); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + } + if (!is_varlen_k) { + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + } + + if (seqused_q_.has_value()){ + auto seqused_q = seqused_q_.value(); + TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32"); + CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q); + CHECK_SHAPE(seqused_q, batch_size); + } + if (seqused_k_.has_value()){ + auto seqused_k = seqused_k_.value(); + TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32"); + CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k); + CHECK_SHAPE(seqused_k, batch_size); + } + + at::Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_type, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + if (!is_varlen_q) { + CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); + } else { + CHECK_SHAPE(dq, total_q, num_heads, head_size); + } + } else { + dq = torch::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_type, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + if (!is_varlen_k) { + CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size); + } else { + CHECK_SHAPE(dk, total_k, num_heads_k, head_size); + } + } else { + dk = torch::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_type, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + if (!is_varlen_k) { + CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(dv, total_k, num_heads_k, head_size_v); + } + } else { + dv = torch::empty_like(v); + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64 + at::Tensor softmax_d, softmax_lse_log2; + if (!is_varlen) { + // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64 + softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, 
opts.dtype(at::kFloat)); + softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); + } else { + softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat)); + softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat)); + } + at::Tensor dq_accum, dk_accum, dv_accum; + if (!is_varlen) { + dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, opts.dtype(at::kFloat)); + } else { + dq_accum = torch::empty({num_heads, total_q_padded_rounded * head_size_rounded}, opts.dtype(at::kFloat)); + } + if (num_heads_k != num_heads) { // MQA / GQA + if (!is_varlen) { + dk_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat)); + dv_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_v_rounded}, opts.dtype(at::kFloat)); + } else { + dk_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + dv_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_v_rounded}, opts.dtype(at::kFloat)); + } + } + + Flash_bwd_params params; + set_params_dgrad(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + dout, dq, dk, dv, + !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(), + !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(), + seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr, + seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr, + dq_accum.data_ptr(), + num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr, + num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr, + softmax_lse.data_ptr(), + softmax_d.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + window_size_left, + window_size_right, + 0, // attention_chunk + softcap, + deterministic, + sm_margin); + params.total_q = total_q; + params.total_k = total_k; + params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr(); + params.dv = head_size_v; + params.dv_rounded = head_size_v_rounded; + + // auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32)); + // params.tile_count_semaphore = tile_count_semaphore.data_ptr(); + // Will be zero'ed out in the backward preprocess kernel + at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32)); + params.dq_semaphore = dq_semaphore.data_ptr(); + if (num_heads_k != num_heads && params.deterministic) { + // TODO: do we need to zero them out? 
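+        // Sizing example (illustrative numbers, editorial): with seqlen_k = 1000 and
+        // kBlockN = 128, (seqlen_k + kBlockN - 1) / kBlockN = 8 n-blocks, so each of the
+        // dk/dv semaphore tensors below holds 8 * batch_size * num_heads_k int32 counters.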
+ at::Tensor dk_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); + at::Tensor dv_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); + params.dk_semaphore = dk_semaphore.data_ptr(); + params.dv_semaphore = dv_semaphore.data_ptr(); + } + + #ifdef FLASHATTENTION_DISABLE_LOCAL + TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); + #endif + #ifdef FLASHATTENTION_DISABLE_SOFTCAP + TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping."); + #endif + + if (total_q > 0 && total_k > 0 && num_heads_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_bwd(params, stream); + } else if (total_k > 0 && num_heads_k > 0) { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. + dk.zero_(); + dv.zero_(); + softmax_d.zero_(); + } else if (total_q > 0 && num_heads_k > 0) { + dq.zero_(); + softmax_d.zero_(); + } + + return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum }; +} + +std::tuple +mha_combine(at::Tensor out_partial, // num_splits x batch_size x seqlen x num_heads x head_size + at::Tensor lse_partial, // num_splits x batch_size x seqlen x num_heads + std::optional out_, // batch_size x seqlen x num_heads x head_size + std::optional out_dtype_ + ) { + + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm8x = dprops->major >= 8; + TORCH_CHECK(is_sm8x, "Attention combine function only supports Ampere GPUs or newer."); + + auto out_partial_type = out_partial.scalar_type(); + TORCH_CHECK(out_partial_type == at::ScalarType::Float, "Attention combine function only support fp32 data type"); + TORCH_CHECK(lse_partial.scalar_type() == at::ScalarType::Float, "Attention combine function only support fp32 data type"); + + CHECK_DEVICE(out_partial); CHECK_DEVICE(lse_partial); + + TORCH_CHECK(out_partial.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(lse_partial.stride(-2) == 1, "LSE tensor must be contiguous in the seqlen dimension"); + + const auto sizes = out_partial.sizes(); + + const int num_splits = sizes[0]; + const int batch_size = sizes[1]; + const int seqlen = sizes[2]; + const int num_heads = sizes[3]; + const int head_size_og = sizes[4]; + TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256"); + + CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og); + CHECK_SHAPE(lse_partial, num_splits, batch_size, seqlen, num_heads); + + int const alignment = 4; + at::Tensor out_partial_padded; + auto pad = [](at::Tensor x, int alignment) { + return x.size(-1) % alignment == 0 ? 
x : torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment})); + }; + out_partial_padded = pad(out_partial, alignment); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, alignment); + + auto opts = out_partial.options(); + at::ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type()); + TORCH_CHECK(out_type == at::ScalarType::Float || out_type == at::ScalarType::BFloat16 || out_type == at::ScalarType::Half, "Output type must be FP32, FP16 or BF16"); + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.scalar_type() == out_type); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, seqlen, num_heads, head_size_og); + if (head_size_og % alignment != 0) { + out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type)); + } + } else { + out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type)); + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)out_partial.get_device()}; + + auto softmax_lse = torch::empty({batch_size, num_heads, seqlen}, opts.dtype(at::kFloat)).transpose(1, 2); + + Flash_fwd_params params {}; // Need to reset the params to set everything to zero + params.is_fp32 = out_type == at::ScalarType::Float; + params.is_bf16 = out_type == at::ScalarType::BFloat16; + params.oaccum_ptr = out_partial_padded.data_ptr(); + params.softmax_lseaccum_ptr = lse_partial.data_ptr(); + params.o_ptr = out.data_ptr(); + params.softmax_lse_ptr = softmax_lse.data_ptr(); + params.b = batch_size; + params.h = num_heads; + params.seqlen_q = seqlen; + params.dv = head_size; + params.num_splits = num_splits; + params.oaccum_split_stride = out_partial_padded.stride(0); + params.oaccum_row_stride = out_partial_padded.stride(2); + params.oaccum_head_stride = out_partial_padded.stride(3); + params.oaccum_batch_stride = out_partial_padded.stride(1); + params.lseaccum_split_stride = lse_partial.stride(0); + params.lseaccum_head_stride = lse_partial.stride(3); + params.lseaccum_batch_stride = lse_partial.stride(1); + params.o_row_stride = out.stride(1); + params.o_head_stride = out.stride(2); + params.o_batch_stride = out.stride(0); + params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; + + if (seqlen > 0 && batch_size > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd_combine(params, stream, false /*enable_pdl*/); + } + + at::Tensor out_padded = out; + if (head_size_og % alignment != 0) { + out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + // if (out_.has_value()) { out_.value().copy_(out); } + } + + return {out, softmax_lse}; +} + +#ifdef false + +TORCH_LIBRARY(flash_attn_3, m) { + m.def("fwd(" + "Tensor q," + "Tensor k," + "Tensor v," + "Tensor(k_new!)? k_new = None," + "Tensor(v_new!)? v_new = None," + "Tensor? q_v = None," + "Tensor(out!)? out = None," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? cu_seqlens_k_new = None," + "Tensor? seqused_q = None," + "Tensor? seqused_k = None," + "int? max_seqlen_q = None," + "int? max_seqlen_k = None," + "Tensor? page_table = None," + "Tensor? kv_batch_idx = None," + "Tensor? 
leftpad_k = None," + "Tensor? rotary_cos = None," + "Tensor? rotary_sin = None," + "Tensor? seqlens_rotary = None," + "Tensor? q_descale = None," + "Tensor? k_descale = None," + "Tensor? v_descale = None," + "float? softmax_scale = None," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "int attention_chunk = 0," + "float softcap = 0.0," + "bool is_rotary_interleaved = False," + "Tensor? scheduler_metadata = None," + "int num_splits = 0," + "bool? pack_gqa = None," + "int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)"); + m.def("bwd(" + "Tensor dout," + "Tensor q," + "Tensor k," + "Tensor v," + "Tensor out," + "Tensor softmax_lse," + "Tensor(dq!)? dq = None," + "Tensor(dk!)? dk = None," + "Tensor(dv!)? dv = None," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? seqused_q = None," + "Tensor? seqused_k = None," + "int? max_seqlen_q = None," + "int? max_seqlen_k = None," + "float? softmax_scale = None," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "float softcap = 0.0," + "bool deterministic = False," + "int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)"); + m.def("fwd_combine(" + "Tensor out_partial," + "Tensor lse_partial," + "Tensor(out!)? out = None," + "ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)"); + m.def("get_scheduler_metadata(" + "int batch_size," + "int max_seqlen_q," + "int max_seqlen_k," + "int num_heads," + "int num_heads_k," + "int headdim," + "int headdim_v," + "ScalarType qkv_dtype," + "Tensor seqused_k," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? cu_seqlens_k_new = None," + "Tensor? seqused_q = None," + "Tensor? leftpad_k = None," + "int? page_size = None," + "int max_seqlen_k_new = 0," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "int attention_chunk = 0," + "bool has_softcap = False," + "int num_splits = 0," + "bool? pack_gqa = None," + "int sm_margin = 0) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) { + m.impl("fwd", &mha_fwd); + m.impl("bwd", &mha_bwd); + m.impl("fwd_combine", &mha_combine); + m.impl("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata); +} + +#endif diff --git a/flash-attn/flash_bwd_kernel_sm80.h b/flash-attn/flash_bwd_kernel_sm80.h new file mode 100644 index 0000000000000000000000000000000000000000..aaec00dbe4a493a159629a19df75e5eaafac2c83 --- /dev/null +++ b/flash-attn/flash_bwd_kernel_sm80.h @@ -0,0 +1,173 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include +#include +#include +#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +class FlashAttnBwdSm80 { + +public: + + // Type Aliases + static constexpr bool Is_causal = CollectiveMainloop_::Is_causal; + static constexpr bool Is_local = CollectiveMainloop_::Is_local; + static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen); + static constexpr bool Varlen = CollectiveMainloop_::Varlen; + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK; + using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP; + using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV; + using ArchTag = typename CollectiveMainloop::ArchTag; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(ArchTag::kMinComputeCapability >= 80); + + using TileScheduler = TileScheduler_; + using TileSchedulerArguments = typename flash::TileSchedulerArguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMmaSdP{})); + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{})); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + // Kernel level shared memory storage + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128> { + union { + typename CollectiveMainloop::TensorStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + }; + } tensors; + + alignas(16) typename TileScheduler::SharedStorage smem_scheduler; + + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + cutlass::KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + cutlass::KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. 
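// A minimal host-side sketch of how this Arguments -> Params conversion is typically consumed
// (illustrative only; `Kernel` stands for a fully instantiated FlashAttnBwdSm80, and the *_args
// variables are assumed to be already populated, as is done in flash_bwd_launch_template.h):
//
//   typename Kernel::Arguments args {mainloop_args, epilogue_args, {device_id, num_sm}, scheduler_args};
//   typename Kernel::Params params = Kernel::to_underlying_arguments(args);
//   dim3 grid  = Kernel::get_grid_shape(params);
//   dim3 block = Kernel::get_block_shape();
//   cutlass::kernel_launch<Kernel>(grid, block, Kernel::SharedStorageSize, stream,
//                                  params, false /*launch_with_pdl*/);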
+ static + Params + to_underlying_arguments(Arguments const& args) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + return { + CollectiveMainloop::to_underlying_arguments(args.mainloop), + CollectiveEpilogue::to_underlying_arguments(args.epilogue), + hw_info, + TileScheduler::to_underlying_arguments(args.scheduler) + }; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue; + + TileScheduler scheduler(reinterpret_cast(&shared_storage.smem_scheduler)); + // Initialize matmul objects. + TiledMmadKV tiled_mma_dKV; + + scheduler.init_consumer(); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + CUTLASS_PRAGMA_NO_UNROLL + for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work(params.scheduler) : scheduler.template get_initial_work(params.scheduler); + work_tile_info.is_valid(params.scheduler); + work_tile_info = warp_idx == 0 ? scheduler.template get_next_work(params.scheduler, work_tile_info) : scheduler.template get_next_work(params.scheduler, work_tile_info)) { + + auto block_coord_ = work_tile_info.get_block_coord(params.scheduler); + auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_; + cute::tuple block_coord = {n_block, bidh, bidb}; + + // dK and dV output accumulator. + Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); + Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); + bool tile_valid = mainloop.mma(params.mainloop, tdKrdK, tdVrdV, threadIdx.x, + block_coord, shared_storage); + scheduler.prefetch_next_work(params.scheduler, work_tile_info); + if (tile_valid) { + epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV, + threadIdx.x, block_coord); + } else { + epilogue.store_zero(params.epilogue, threadIdx.x, block_coord); + } + } + + } + +}; + +} // namespace flash diff --git a/flash-attn/flash_bwd_kernel_sm90.h b/flash-attn/flash_bwd_kernel_sm90.h new file mode 100644 index 0000000000000000000000000000000000000000..b93a0219161556b0d14f33008e0a448294d609ef --- /dev/null +++ b/flash-attn/flash_bwd_kernel_sm90.h @@ -0,0 +1,282 @@ + +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include +#include +#include +#include +#include +#include +#include "cutlass/pipeline/pipeline.hpp" + +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +class FlashAttnBwdSm90 { + +public: + + // Type Aliases + static constexpr bool Is_causal = CollectiveMainloop_::Is_causal; + static constexpr bool Is_local = CollectiveMainloop_::Is_local; + static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen); + static constexpr bool Varlen = CollectiveMainloop_::Varlen; + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK; + using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP; + using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ClusterShape = typename CollectiveMainloop::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(ArchTag::kMinComputeCapability >= 90); + + using TileScheduler = TileScheduler_; + using TileSchedulerArguments = typename flash::TileSchedulerArguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaSdP{})) / cutlass::NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); + + /// Register requirement for Load and Math WGs + static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 2 ? 24 : 32; + static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 240 : 160; + // If you want to print from the producer warp, you'd need to increase the number of registers + // Otherwise you'll get CUDA error. + // static constexpr uint32_t LoadRegisterRequirement = 40; + // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 
232 : 152; + + // Kernel level shared memory storage + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128> { + union { + typename CollectiveMainloop::TensorStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + }; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> { + alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_KV; + alignas(16) typename CollectiveMainloop::MainloopPipeline::SharedStorage pipeline_q; + alignas(16) typename CollectiveMainloop::MainloopPipeline_dO::SharedStorage pipeline_do; + alignas(16) typename TileScheduler::SharedStorage smem_scheduler; + } pipelines; + + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + cutlass::KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + cutlass::KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + return { + CollectiveMainloop::to_underlying_arguments(args.mainloop), + CollectiveEpilogue::to_underlying_arguments(args.epilogue), + hw_info, + TileScheduler::to_underlying_arguments(args.scheduler) + }; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + + static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; + static constexpr int NumCopyThreads = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup; + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + using MainloopPipeline_dO = typename CollectiveMainloop::MainloopPipeline_dO; + using PipelineParams_dO = typename MainloopPipeline_dO::Params; + using PipelineState_dO = typename MainloopPipeline_dO::PipelineState; + static constexpr bool Q_dO_same_stages = std::is_same_v; + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int const lane_predicate = cute::elect_one_sync(); + int const warp_idx = cutlass::canonical_warp_idx_sync(); + + // Issue Tma Descriptor Prefetch from a single thread + if 
(warp_idx == 0 && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Obtain warp index + int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + + PipelineParams pipeline_params; + pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesQ + CollectiveMainloop::TmaTransactionBytesLSE; + int warp_group_idx = cutlass::canonical_warp_group_idx(); + pipeline_params.role = warp_group_idx == 0 + ? MainloopPipeline::ThreadCategory::Producer + : MainloopPipeline::ThreadCategory::Consumer; + pipeline_params.is_leader = warp_group_thread_idx == 0; + pipeline_params.num_consumers = NumMmaThreads; + + if (warp_idx == 0 && lane_predicate) { + shared_storage.pipelines.barrier_KV.init(1 /*numThreads*/); + } + // We're counting on pipeline_q to call cutlass::arch::fence_barrier_init(); + MainloopPipeline pipeline_q(shared_storage.pipelines.pipeline_q, pipeline_params, ClusterShape{}); + auto role_dO = warp_group_idx == 0 + ? MainloopPipeline_dO::ThreadCategory::Producer + : MainloopPipeline_dO::ThreadCategory::Consumer; + PipelineParams_dO pipeline_params_dO {pipeline_params.transaction_bytes, role_dO, pipeline_params.is_leader, pipeline_params.num_consumers}; + MainloopPipeline_dO pipeline_do(shared_storage.pipelines.pipeline_do, cute::conditional_return(pipeline_params, pipeline_params_dO), ClusterShape{}); + + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue; + + // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } else { + __syncthreads(); + } + + TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); + + if (warp_group_idx == 0) { // Producer + cutlass::arch::warpgroup_reg_dealloc(); + + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + if (warp_idx_in_warpgroup == 0) { // Load K, V, and do TMA on Q and dO + PipelineState smem_pipe_write = cutlass::make_producer_start_state(); + PipelineState_dO smem_pipe_write_do = cutlass::make_producer_start_state(); + for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); + work_tile_info.is_valid(params.scheduler); + work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { + auto block_coord_ = work_tile_info.get_block_coord(params.scheduler); + auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_; + cute::tuple block_coord = {n_block, bidh, bidb}; + auto scheduler_prefetch = [&scheduler, ¶ms, &work_tile_info]() { + scheduler.prefetch_next_work(params.scheduler, work_tile_info); + }; + mainloop.load(params.mainloop, pipeline_q, pipeline_do, smem_pipe_write, + smem_pipe_write_do, shared_storage, scheduler_prefetch, block_coord); + } + mainloop.load_tail(pipeline_q, pipeline_do, smem_pipe_write, smem_pipe_write_do); + } else if (warp_idx_in_warpgroup == 1) { + for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); + work_tile_info.is_valid(params.scheduler); + work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { + auto block_coord_ = work_tile_info.get_block_coord(params.scheduler); + auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_; + cute::tuple block_coord = {n_block, bidh, bidb}; + mainloop.store_dq(params.mainloop, shared_storage, 
block_coord); + } + } + } else { // Consumer + cutlass::arch::warpgroup_reg_alloc(); + // Initialize matmul objects. + TiledMmadKV tiled_mma_dKV; + + PipelineState smem_pipe_read; + PipelineState_dO smem_pipe_read_do; + + mainloop.mma_init(); + scheduler.init_consumer(); + + int work_idx = 0; + CUTLASS_PRAGMA_NO_UNROLL + for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); + work_tile_info.is_valid(params.scheduler); + work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { + auto block_coord_ = work_tile_info.get_block_coord(params.scheduler); + auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_; + cute::tuple block_coord = {n_block, bidh, bidb}; + + // dK and dV output accumulator. + Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); + Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); + bool tile_valid = mainloop.mma( + params.mainloop, pipeline_q, pipeline_do, smem_pipe_read, smem_pipe_read_do, + tdKrdK, tdVrdV, threadIdx.x - NumCopyThreads, work_idx, block_coord, shared_storage); + if (tile_valid) { + epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV, + threadIdx.x - NumCopyThreads, block_coord); + } else { + epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord); + } + + } + epilogue.store_tail(); + } + + } + +}; + +} // namespace flash diff --git a/flash-attn/flash_bwd_launch_template.h b/flash-attn/flash_bwd_launch_template.h new file mode 100644 index 0000000000000000000000000000000000000000..b6e8810b25fa9af90266324d98b09dae4e239d73 --- /dev/null +++ b/flash-attn/flash_bwd_launch_template.h @@ -0,0 +1,390 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include "cutlass/device_kernel.h" // For device_kernel +#include "cutlass/kernel_launch.h" // For kernel_launch +#include "cutlass/cluster_launch.hpp" // For ClusterLauncher + +#include "static_switch.h" +#include "flash.h" +#include "flash_bwd_preprocess_kernel.h" +#include "flash_bwd_postprocess_kernel.h" +#include "tile_scheduler.hpp" +#include "mainloop_bwd_sm90_tma_gmma_ws.hpp" +#include "mainloop_bwd_sm80.hpp" +#include "epilogue_bwd.hpp" +#include "flash_bwd_kernel_sm90.h" +#include "flash_bwd_kernel_sm80.h" + +using namespace cute; + +template +void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { + static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time."); + using ElementAccum = float; + using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; + + int const total_q_padded_rounded = cute::round_up(params.total_q + params.b * kBlockM, kBlockM); + int const total_k_padded_rounded = cute::round_up(params.total_k + params.b * kBlockN, kBlockN); + bool const is_varlen_q = params.cu_seqlens_q; + bool const is_varlen_k = params.cu_seqlens_k; + int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q; + int seqlen_k = !is_varlen_k ? params.seqlen_k : params.total_k; + int seqlen_q_rounded = !is_varlen_q ? params.seqlen_q_rounded : total_q_padded_rounded; + int seqlen_k_rounded = !is_varlen_k ? params.seqlen_k_rounded : total_k_padded_rounded; + int batch_q = !is_varlen_q ? 
params.b : 1; + int batch_k = !is_varlen_k ? params.b : 1; + + using TileShape_MK = cute::Shape, Int>; + using PreprocessKernel = flash::FlashAttnBwdPreprocess; + typename PreprocessKernel::Arguments preprocess_args { + static_cast(params.o_ptr), + {seqlen_q, params.dv, params.h, batch_q}, // shape_O + {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0}, // stride_O + static_cast(params.do_ptr), + {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO + static_cast(params.dsoftmax_sum), + {seqlen_q_rounded, params.h, batch_q}, // shape_dPsum + {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum + static_cast(params.softmax_lse_ptr), + {_1{}, seqlen_q, !is_varlen_q ? params.h * params.seqlen_q : 0}, // stride_LSE + static_cast(params.softmax_lse_log2_ptr), + {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2 + static_cast(params.dq_accum_ptr), + {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum + {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * seqlen_q_rounded * params.h : 0}, // stride_dQaccum + params.b, + params.dq_semaphore, + params.cu_seqlens_q, + params.seqused_q + }; + typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args); + int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM); + dim3 grid_m(num_m_block, params.h, params.b); + cutlass::kernel_launch(grid_m, PreprocessKernel::MaxThreadsPerBlock, PreprocessKernel::SharedStorageSize, stream, preprocess_params, false /*launch_with_pdl*/); + CHECK_CUDA_KERNEL_LAUNCH(); + + using TileShape_MNK = cute::Shape, Int, Int>; + using ClusterShape = cute::Shape<_1, Int<1>, _1>; // Currently doesn't not support cluster + // Stages_dS_or_QSm80 is Stages_dS if Sm90 and Stages if Sm80 + static constexpr int Stages = Arch >= 90 ? 2 : Stages_dS_or_QSm80; + static constexpr int Stages_dS = Arch >= 90 ? Stages_dS_or_QSm80 : 1; + using CollectiveMainloop = std::conditional_t< + Arch >= 90, + flash::CollectiveMainloopBwdSm90, + flash::CollectiveMainloopBwdSm80 + >; + using CollectiveEpilogue = std::conditional_t< + !GQA, + flash::CollectiveEpilogueBwd= 90 ? 1 : cutlass::NumWarpsPerWarpGroup) / AtomLayoutNdKV>, + flash::CollectiveEpilogueBwdGQA + >; + using Scheduler = std::conditional_t< + Is_causal && !Varlen, + flash::SingleTileBwdLPTScheduler, + flash::SingleTileScheduler + >; + using AttnKernel = std::conditional_t< + Arch >= 90, + flash::enable_sm90_or_later>, + flash::enable_sm80_to_sm89> + >; + + typename CollectiveMainloop::Arguments mainloop_args { + static_cast(params.q_ptr), + {seqlen_q, params.d, params.h, batch_q}, // shape_Q + {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q + static_cast(params.k_ptr), + {seqlen_k, params.d, params.h_k, batch_k}, // shape_K + {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K + static_cast(params.v_ptr), + {seqlen_k, params.dv, params.h_k, batch_k}, // shape_V + {params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0}, // stride_V + static_cast(params.do_ptr), + {seqlen_q, params.dv, params.h, batch_q}, // shape_dO + {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? 
params.do_batch_stride : 0}, // stride_dO + static_cast(params.dq_accum_ptr), + {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum + {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum + static_cast(params.softmax_lse_log2_ptr), + {seqlen_q_rounded, params.h, batch_q}, // shape_LSE + {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2 + static_cast(params.dsoftmax_sum), + {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum + params.scale_softmax, + params.window_size_left, params.window_size_right, 0 /*attention_chunk*/, + params.softcap, + params.b, + params.dq_semaphore, + params.cu_seqlens_q, params.cu_seqlens_k, + params.seqused_q, params.seqused_k + }; + // The case work with GQA is ugly but idk how to fix it. + typename CollectiveEpilogue::Arguments epilogue_args { + static_cast(!GQA ? params.dk_ptr : params.dk_accum_ptr), + [&] { + if constexpr (!GQA) { + return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.d, params.h, batch_k}; // shape_dK + } else { + return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}; // shape_dKaccum + } + }(), + [&] { + if constexpr (!GQA) { + return typename CollectiveEpilogue::StridedKV {params.dk_row_stride, _1{}, params.dk_head_stride, !is_varlen_k ? params.dk_batch_stride : 0}; // stride_dK + } else { + return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dKaccum + } + }(), + static_cast(!GQA ? params.dv_ptr : params.dv_accum_ptr), + [&] { + if constexpr (!GQA) { + return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.dv, params.h, batch_k}; // shape_dV + } else { + return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}; // shape_dVaccum + } + }(), + [&] { + if constexpr (!GQA) { + return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0}; // stride_dV + } else { + return typename CollectiveEpilogue::StridedKV {_1{}, params.dv_rounded * seqlen_k_rounded, !is_varlen_k ? 
params.h_k * params.dv_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum + } + }(), + params.h, + params.dk_semaphore, + params.dv_semaphore, + params.cu_seqlens_k, + params.seqused_k, + }; + + int num_blocks_n = cutlass::ceil_div(params.seqlen_k, get<1>(TileShape_MNK{})); + num_blocks_n = cutlass::round_up(num_blocks_n, size<1>(ClusterShape{})); + typename flash::TileSchedulerArguments scheduler_args { + num_blocks_n, params.h, params.b, 1 /*num_splits*/, + params.h / params.h_k, + params.seqlen_k, + params.seqlen_q, params.d, params.dv, sizeof(Element), + params.tile_count_semaphore, params.cu_seqlens_k, params.seqused_k + }; + + int device; + cudaGetDevice(&device); + typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({ + mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args + }); + + dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params); + dim3 block_dims = AttnKernel::get_block_shape(); + int smem_size = AttnKernel::SharedStorageSize; + // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q)); + // int smem_size_do = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_do)); + // int smem_size_ds = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_ds)); + // int smem_size_dqacc = [&] { + // if constexpr (Arch >= 90) { + // return sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dqacc)); + // } else { + // return 0; + // } + // }(); + // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k)); + // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)); + // int smem_size_lse = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_lse)); + // int smem_size_dpsum = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dpsum)); + // printf("smem_size = %d, q = %d, k = %d, v = %d, do = %d, ds = %d, dqacc = %d, lse = %d, dpsum = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_do, smem_size_ds, smem_size_dqacc, smem_size_lse, smem_size_dpsum); + if constexpr (size(ClusterShape{}) > 1) { + void const* kernel = (void const*) cutlass::device_kernel; + if (smem_size >= 48 * 1024) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); + cutlass::ClusterLauncher::launch( + grid_dims, cluster_dims, block_dims, smem_size, stream, kernel, kernel_params, false /*launch_with_pdl*/); + } else { + if (smem_size >= 48 * 1024) { + CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + cutlass::kernel_launch(grid_dims, block_dims, smem_size, stream, kernel_params, false /*launch_with_pdl*/); + } + CHECK_CUDA_KERNEL_LAUNCH(); + + using PostprocessKernel = flash::FlashAttnBwdPostprocessConvertdQ; + typename PostprocessKernel::Arguments postprocess_args { + static_cast(params.dq_accum_ptr), + {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum + {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? 
params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum + static_cast(params.dq_ptr), + {seqlen_q, params.d, params.h, batch_q}, // shape_dQ + {params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride}, // stride_dQ + params.scale_softmax, + params.cu_seqlens_q, + params.seqused_q + }; + typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args); + int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{})); + dim3 grid_m_postprocess(num_m_block_postprocess, params.h, params.b); + int smem_size_postprocess = PostprocessKernel::SharedStorageSize; + if (smem_size_postprocess >= 48 * 1024) { + CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess)); + } + cutlass::kernel_launch(grid_m_postprocess, PostprocessKernel::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_params, false /*launch_with_pdl*/); + CHECK_CUDA_KERNEL_LAUNCH(); + + if constexpr (GQA) { + using TileShape_NK = cute::Shape, Int>; + using PostprocessKerneldKV = flash::FlashAttnBwdPostprocessConvertdQ; + typename PostprocessKerneldKV::Arguments postprocess_dK_args { + static_cast(params.dk_accum_ptr), + {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dKaccum + {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dKaccum + static_cast(params.dk_ptr), + {seqlen_k, params.d, params.h_k, batch_k}, // shape_dK + {params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride}, // stride_dK + 1.f, + params.cu_seqlens_k, + params.seqused_k + }; + typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args); + typename PostprocessKerneldKV::Arguments postprocess_dV_args { + static_cast(params.dv_accum_ptr), + {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}, // shape_dVaccum + {_1{}, seqlen_k_rounded * params.dv_rounded, !is_varlen_k ? 
params.dv_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dVaccum + static_cast(params.dv_ptr), + {seqlen_k, params.dv, params.h_k, batch_k}, // shape_dV + {params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride}, // stride_dV + 1.f, + params.cu_seqlens_k, + params.seqused_k + }; + typename PostprocessKerneldKV::Params postprocess_dV_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dV_args); + int num_n_block_postprocess = cute::ceil_div(params.seqlen_k, get<0>(TileShape_NK{})); + dim3 grid_n_postprocess(num_n_block_postprocess, params.h_k, params.b); + int smem_size_postprocess = PostprocessKerneldKV::SharedStorageSize; + if (smem_size_postprocess >= 48 * 1024) { + CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess)); + } + cutlass::kernel_launch(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dK_params, false /*launch_with_pdl*/); + CHECK_CUDA_KERNEL_LAUNCH(); + cutlass::kernel_launch(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dV_params, false /*launch_with_pdl*/); + CHECK_CUDA_KERNEL_LAUNCH(); + } + +} + +template +void run_mha_bwd_dispatch(Flash_bwd_params ¶ms, cudaStream_t stream) { + VARLEN_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] { + BOOL_SWITCH(params.h != params.h_k, GQA, [&] { +// BOOL_SWITCH(params.deterministic, Deterministic, [&] { + // run_flash_bwd(params, stream); + run_flash_bwd(params, stream); +// }); + }); + }); +} + + +template +void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) { + CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] { + if constexpr (Arch >= 90) { + if constexpr (Is_causal && Has_softcap) { + // register spill with 128 x 128 + run_mha_bwd_dispatch(params, stream); + } else { + // With ShuffleStats we no longer have register spilling when Has_softcap and using 128 x 128 block. 
+ run_mha_bwd_dispatch(params, stream); + } + } else if constexpr (Arch == 86 || Arch == 89) { + run_mha_bwd_dispatch(params, stream); + // run_mha_bwd_dispatch(params, stream); + // run_mha_bwd_dispatch(params, stream); + // run_mha_bwd_dispatch(params, stream); + } else { + run_mha_bwd_dispatch(params, stream); + } + }); +} + +template +void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) { + CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] { + if constexpr (Arch >= 90) { + run_mha_bwd_dispatch(params, stream); + } else if constexpr (Arch == 86 || Arch == 89) { + run_mha_bwd_dispatch(params, stream); + } else { + run_mha_bwd_dispatch(params, stream); + } + }); +} + +template +void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { + CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] { + if constexpr (Arch >= 90) { + if constexpr (Is_causal || Is_local || Has_softcap) { + run_mha_bwd_dispatch(params, stream); + } else { + run_mha_bwd_dispatch(params, stream); + } + } else if constexpr (Arch == 86 || Arch == 89) { + run_mha_bwd_dispatch(params, stream); + } else { + run_mha_bwd_dispatch(params, stream); + } + }); +} + +template +void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream) { + CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] { + if constexpr (Arch >= 90) { + run_mha_bwd_dispatch(params, stream); + } else if constexpr (Arch == 86 || Arch == 89) { + run_mha_bwd_dispatch(params, stream); + } else { + run_mha_bwd_dispatch(params, stream); + } + }); +} + +template +void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream) { + CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] { + if constexpr (Arch >= 90) { + run_mha_bwd_dispatch(params, stream); + } else if constexpr (Arch == 86 || Arch == 89) { + run_mha_bwd_dispatch(params, stream); + // run_mha_bwd_dispatch(params, stream); + } else { + run_mha_bwd_dispatch(params, stream); + } + }); +} diff --git a/flash-attn/flash_bwd_postprocess_kernel.h b/flash-attn/flash_bwd_postprocess_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..c91e261507dd906ebb457db3d3f5dd10ef4b5109 --- /dev/null +++ b/flash-attn/flash_bwd_postprocess_kernel.h @@ -0,0 +1,256 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include +#include +#include +#include +#include "cutlass/arch/barrier.h" + +#include "seqlen.h" +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +class FlashAttnBwdPostprocessConvertdQ { + +public: + + // Type Aliases + using TileShape_MK = TileShape_MK_; + using ArchTag = ArchTag_; + + static_assert(ArchTag::kMinComputeCapability >= 75); + static constexpr bool IsSm90 = ArchTag::kMinComputeCapability >= 90; + + static constexpr uint32_t MaxThreadsPerBlock = kNThreads; + static constexpr uint32_t MinBlocksPerMultiprocessor = 2; + + static constexpr int kBlockM = get<0>(TileShape_MK{}); + static constexpr int kHeadDim = get<1>(TileShape_MK{}); + static_assert(!IsSm90 || kNThreads % cutlass::NumThreadsPerWarpGroup == 0, "kNThreads must be a multiple of NumThreadsPerWarpGroup"); + static constexpr int NumdQWarpGgroups = kNThreads / cutlass::NumThreadsPerWarpGroup; + using R2SLayoutAtomdQaccum = std::conditional_t< + IsSm90, + Layout, Int>>, + Layout>> + >; + using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2SLayoutAtomdQaccum{}, + Layout>>{})); // Val layout, 1 or 4 vals per read + using G2SLayoutAtomdQaccum = Layout>>; + // UniversalCopy instead of AutoVectorizingCopyWithAssumedAlignment as the latter generates cp.async instructions + using G2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, G2SLayoutAtomdQaccum{}, + Layout>{})); // Val layout, 4 vals per read + // We don't do bound checking for the gmem -> smem load so we just assert here. + static_assert(IsSm90 || (kBlockM * kHeadDim) % (kNThreads * 4) == 0); + static constexpr int SmemdQaccumSize = size(TileShape_MK{}); + using SmemLayoutdQaccumFlat = Layout>>; + using SmemLayoutdQaccum = std::conditional_t< + IsSm90, + Layout, Int>>, + Layout>> + >; + + // We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs, + // then setting kBlockKSmem to 32 will cause "Static shape_div failure". + // We want to treat it as 64 x 48, so kBlockKSmem should be 16. + static constexpr int MmaShapeN = get<1>(typename TiledMma::AtomShape_MNK{}); + static constexpr int kBlockKSmem = MmaShapeN % 64 == 0 ? 64 : (MmaShapeN % 32 == 0 ? 32 : 16); + static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 
2 : 1); + using SmemLayoutAtomdQ = + decltype(composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutdQ = decltype(tile_to_shape(SmemLayoutAtomdQ{}, TileShape_MK{})); + using SmemLayoutdQt = + decltype(cute::composition(SmemLayoutdQ{}, + make_layout(make_shape(get<1>(TileShape_MK{}), get<0>(TileShape_MK{})), + make_stride(Int(TileShape_MK{})>{}, _1{})))); + + using SmemCopyAtomdQ = Copy_Atom< + std::conditional_t< + IsSm90, + std::conditional_t, + AutoVectorizingCopyWithAssumedAlignment<128> + >, + Element>; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, int(MaxThreadsPerBlock)); + static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per load + + struct SharedStorage : cute::aligned_struct<128> { + cute::array_aligned> smem_dqacc; + cute::array_aligned> smem_dq; + alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_dQaccum; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + using ShapedQ = cute::Shape; // (seqlen_q, d, head, batch) + using StridedQ = cute::Stride; + using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch) + using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; + + // Device side arguments + struct Arguments { + ElementAccum const* ptr_dQaccum; + ShapedQaccum const shape_dQaccum; + StridedQaccum const stride_dQaccum; + Element* ptr_dQ; + ShapedQ const shape_dQ; + StridedQ const stride_dQ; + float const softmax_scale; + int const* cu_seqlens = nullptr; + int const* seqused = nullptr; + }; + + // Kernel entry point API + struct Params { + ElementAccum const* ptr_dQaccum; + ShapedQaccum const shape_dQaccum; + StridedQaccum const stride_dQaccum; + Element* ptr_dQ; + ShapedQ const shape_dQ; + StridedQ const stride_dQ; + float const softmax_scale; + int const* cu_seqlens = nullptr; + int const* seqused = nullptr; + }; + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. 
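// Functionally, steps 1-5 of operator() below reduce to "dQ = convert_to<Element>(softmax_scale * dQaccum)".
// A scalar reference sketch (illustrative only; it assumes a plain row-major accumulator layout for
// readability, whereas the real dQaccum is kept in the tiled-MMA fragment order that the copy atoms
// above are built to handle, and the kernel stages everything through shared memory with vectorized,
// coalesced loads and stores):
//
//   for (int row = 0; row < seqlen_q; ++row) {
//     for (int d = 0; d < headdim; ++d) {
//       dQ(row, d) = static_cast<Element>(softmax_scale * dQaccum(row, d));
//     }
//   }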
+ static + Params + to_underlying_arguments(Arguments const& args) { + return { + args.ptr_dQaccum, + args.shape_dQaccum, + args.stride_dQaccum, + args.ptr_dQ, + args.shape_dQ, + args.stride_dQ, + args.softmax_scale, + args.cu_seqlens, + args.seqused + }; + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + + static constexpr int kBlockM = get<0>(TileShape_MK{}); + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + Tensor sdQaccum = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccum{}); + Tensor sdQaccum_flat = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccumFlat{}); + Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQ{}); + Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQt{}); + + int const thread_idx = threadIdx.x; + int const m_block = blockIdx.x; + int const bidh = blockIdx.y; + int const bidb = blockIdx.z; + + flash::SeqlenInfo seqlen_info(bidb, size<0>(params.shape_dQ), params.cu_seqlens, params.seqused); + bool const is_varlen = params.cu_seqlens; + if (is_varlen && m_block * kBlockM >= seqlen_info.seqlen) { return; } + + // Step 1: load dQaccum from gmem to smem + Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), + params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0); + Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(m_block)); // (M * K) + if constexpr (IsSm90) { // Use BulkCopy + static constexpr uint32_t TmaTransactionBytesdQaccum = static_cast(size(SmemLayoutdQaccumFlat{}) * cute::sizeof_bits_v / 8); + auto bulk_copy = Copy_Traits{}; + // if (thread0()) { print(gdQaccum); printf("\n"); print(sdQaccum_flat); printf("\n"); } + if (thread_idx == 0) { + shared_storage.barrier_dQaccum.init(1 /*numThreads*/); + shared_storage.barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum); + copy(bulk_copy.with(*reinterpret_cast(&shared_storage.barrier_dQaccum)), gdQaccum, sdQaccum_flat); + } + __syncthreads(); + shared_storage.barrier_dQaccum.wait(0); + } else { + G2STiledCopydQaccum g2s_tiled_copy_dQaccum; + auto g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_thread_slice(thread_idx); + Tensor tdQgdQaccumg2s = g2s_thr_copy_dQaccum.partition_S(gdQaccum); + Tensor tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum); + cute::copy(g2s_tiled_copy_dQaccum, tdQgdQaccumg2s, tdQsdQaccumg2s); + __syncthreads(); + } + + // __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccum); } + + // Step 2: Load dQaccum from smem to register, then convert fp32 -> fp16/bf16 + R2STiledCopydQaccum s2r_tiled_copy_dQaccum; + auto s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_thread_slice(thread_idx); + Tensor tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum); + TiledMma tiled_mma_dQ; + Tensor taccdQrdQaccum = partition_fragment_C(tiled_mma_dQ, select(TileShape_MK{})); + // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tiled_mma_dQ); printf("\n"); } + // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tdQsdQaccum); } + // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(taccdQrdQaccum); } + CUTE_STATIC_ASSERT_V(size(taccdQrdQaccum) == size(tdQsdQaccum)); + Tensor tdQrdQaccum = s2r_thr_copy_dQaccum.retile_D(taccdQrdQaccum); + cute::copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum); + #pragma unroll + for (int i = 0; 
i < size(taccdQrdQaccum); ++i) { taccdQrdQaccum(i) *= params.softmax_scale; } + // Convert tdQrdQ from fp32 to fp16 + Tensor rdQ = make_tensor_like(taccdQrdQaccum); + flash::convert_type_out(taccdQrdQaccum, rdQ); + + // Step 3: Copy dQ from register to smem + auto smem_tiled_copy_dQ = make_tiled_copy_C(SmemCopyAtomdQ{}, tiled_mma_dQ); + auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(thread_idx); + Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N) + // if (cute::thread0()) { print(smem_tiled_copy_dQ); } + // if (cute::thread0()) { print(smem_thr_copy_dQ); } + // if (cute::thread0()) { print(sdQ); } + Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(cute::conditional_return(sdQ, sdQt)); // ((Atom,AtomNum),PIPE_M,PIPE_N) + cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); + __syncthreads(); + + // Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem + Tensor mdQ = make_tensor(make_gmem_ptr(params.ptr_dQ), params.shape_dQ, params.stride_dQ)(_, _, bidh, !is_varlen ? bidb : 0); + Tensor gdQ = local_tile(domain_offset(make_coord(seqlen_info.offset, _0{}), mdQ), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K) + GmemTiledCopy gmem_tiled_copy_dQ; + auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(thread_idx); + Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); + + Tensor tdQrdQ = make_fragment_like(tdQsdQ); + Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cute::make_identity_tensor(TileShape_MK{})); + Tensor tdQpdQ = make_tensor(make_shape(size<2>(tdQgdQ))); + #pragma unroll + for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(_0{}, _0{}, k)) < get<1>(params.shape_dQ); } + // Need to check OOB when reading from smem if kBlockM isn't evenly tiled + static constexpr bool EvenM = kBlockM % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0; + flash::copy( + gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ, tdQcdQ, tdQpdQ, kBlockM); + + // Step 5: Copy dQ from register to gmem + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, std::min(seqlen_info.seqlen - m_block * kBlockM, kBlockM) + ); + } + +}; + +} // namespace flash diff --git a/flash-attn/flash_bwd_preprocess_kernel.h b/flash-attn/flash_bwd_preprocess_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..85e877f9d4fd2821ef38f78b12f1e1fdde2b6d37 --- /dev/null +++ b/flash-attn/flash_bwd_preprocess_kernel.h @@ -0,0 +1,252 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include +#include +#include +#include + +#include "seqlen.h" +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +class FlashAttnBwdPreprocess { + +public: + + // Type Aliases + using TileShape_MK = TileShape_MK_; + using ArchTag = ArchTag_; + + static_assert(std::is_same_v && ArchTag::kMinComputeCapability >= 75 || + std::is_same_v && ArchTag::kMinComputeCapability >= 80 || + std::is_same_v && ArchTag::kMinComputeCapability >= 89); + + static constexpr uint32_t MaxThreadsPerBlock = 256; + static constexpr uint32_t MinBlocksPerMultiprocessor = 2; + static constexpr int SharedStorageSize = 0; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(get<1>(TileShape_MK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + static constexpr int kBlockM = get<0>(TileShape_MK{}); + static constexpr int kHeadDim = get<1>(TileShape_MK{}); + // We want kBlockKGmem to be a power of 2 so that when we do the summing, + // it's just between threads in the same warp + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; + static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per load + + static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum); + static_assert((kBlockM * kHeadDim / kGmemElemsPerLoadAccum) % MaxThreadsPerBlock == 0, "MaxThreadsPerBlock must divide kBlockM * kHeadDim / kGmemElemsPerLoadAccum"); + using GmemLayoutAtomAccum = Layout>>; + using GmemTiledCopyAccum = decltype( + make_tiled_copy(Copy_Atom, ElementAccum>{}, + GmemLayoutAtomAccum{}, + Layout>>{})); // Val layout, 4 vals per store + + using ShapeO = cute::Shape; // (seqlen_q, d, head, batch) + using StrideO = cute::Stride; + using ShapedPsum = cute::Shape; // (seqlen_q, head, batch) + using StridedPsum = cute::Stride<_1, int64_t, int64_t>; + using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch) + using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; + + // Device side arguments + struct Arguments { + Element const* ptr_O; + ShapeO const shape_O; + StrideO const stride_O; + Element const* ptr_dO; + StrideO const stride_dO; + float* ptr_dPsum; + ShapedPsum const shape_dPsum; + StridedPsum const stride_dPsum; + float const* ptr_LSE; + StridedPsum const stride_LSE; + float *ptr_LSE_log2; + StridedPsum const stride_LSE_log2; + ElementAccum* ptr_dQaccum; + ShapedQaccum const shape_dQaccum; + StridedQaccum const stride_dQaccum; + int num_batch; // We need this to know the size of dq_semaphore in case of varlen + int* dq_semaphore; + int const* cu_seqlens = nullptr; + int const* seqused = nullptr; + }; + + // Kernel entry point API + struct Params { + Element const* ptr_O; + ShapeO const shape_O; + StrideO const stride_O; + Element const* ptr_dO; + StrideO const stride_dO; + float* ptr_dPsum; + ShapedPsum const shape_dPsum; + StridedPsum const stride_dPsum; + float const* ptr_LSE; + StridedPsum const stride_LSE; + float* ptr_LSE_log2; + StridedPsum const stride_LSE_log2; + 
ElementAccum* ptr_dQaccum; + ShapedQaccum const shape_dQaccum; + StridedQaccum const stride_dQaccum; + int num_batch; + int* dq_semaphore; + int const* cu_seqlens = nullptr; + int const* seqused = nullptr; + }; + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args) { + return { + args.ptr_O, + args.shape_O, + args.stride_O, + args.ptr_dO, + args.stride_dO, + args.ptr_dPsum, + args.shape_dPsum, + args.stride_dPsum, + args.ptr_LSE, + args.stride_LSE, + args.ptr_LSE_log2, + args.stride_LSE_log2, + args.ptr_dQaccum, + args.shape_dQaccum, + args.stride_dQaccum, + args.num_batch, + args.dq_semaphore, + args.cu_seqlens, + args.seqused + }; + } + + CUTLASS_DEVICE + void + operator()(Params const& params, [[maybe_unused]] char* smem_buf) { + + static constexpr int kBlockM = get<0>(TileShape_MK{}); + + int const thread_idx = threadIdx.x; + int const m_block = blockIdx.x; + int const bidh = blockIdx.y; + int const bidb = blockIdx.z; + + flash::SeqlenInfo seqlen_info(bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused); + bool const is_varlen = Varlen && params.cu_seqlens; + int const seqlen_o = seqlen_info.seqlen; + if (is_varlen && m_block * kBlockM >= seqlen_o) { return; } + + Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !is_varlen ? bidb : 0); + Tensor gO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K) + Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_O, params.stride_dO)(_, _, bidh, !is_varlen ? bidb : 0); + Tensor gdO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K) + + auto shape_LSE = select<0, 2, 3>(params.shape_O); + Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), shape_LSE, params.stride_LSE)(_, bidh, !is_varlen ? bidb : 0); + Tensor gLSE = local_tile(cute::domain_offset(make_coord(seqlen_info.offset), mLSE), Shape>{}, make_coord(m_block)); + static_assert(kBlockM <= MaxThreadsPerBlock); + float lse = thread_idx < seqlen_o - m_block * kBlockM && thread_idx < kBlockM ? gLSE(thread_idx) : INFINITY; + + GmemTiledCopy gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + + Tensor tOgO = gmem_thr_copy_O.partition_S(gO); + Tensor tOgdO = gmem_thr_copy_O.partition_S(gdO); + // Construct identity layout for gO + Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } + + // (8, kBlockM / 32, kHeadDim / 64) or (8, kBlockM / 16, kHeadDim / 128) + Tensor tOrO = make_fragment_like(tOgO); + Tensor tOrdO = make_fragment_like(tOgdO); + flash::copy( + gmem_tiled_copy_O, tOgO, tOrO, tOcO, tOpO, seqlen_o - m_block * kBlockM + ); + flash::copy( + gmem_tiled_copy_O, tOgdO, tOrdO, tOcO, tOpO, seqlen_o - m_block * kBlockM + ); + // if (threadIdx.x == 222) { printf("bidx = %d, bidy = %d, bidz = %d, seqlen_o = %d, m_block = %d, seqlen_o - m_block * kBlockM = %d, tOgO addr = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, seqlen_o, m_block, seqlen_o - m_block * kBlockM, &tOgO(0));} + + // Reshape from e.g. 
(8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, (8, kHeadDim / 64)) + Layout l = make_layout(get<1>(tOrO.layout()), make_layout(get<0>(tOrO.layout()), get<2>(tOrO.layout()))); + Tensor tOrO_l = make_tensor(tOrO.data(), l); + Tensor o_fp32 = make_tensor_like(tOrO_l); + flash::convert_type_out(tOrO_l, o_fp32); + Tensor tOrdO_l = make_tensor(tOrdO.data(), l); + Tensor do_fp32 = make_tensor_like(tOrdO_l); + flash::convert_type_out(tOrdO_l, do_fp32); + // Sum across the last dimension + Tensor dP_sum = make_tensor(make_shape(size<0>(o_fp32))); + #pragma unroll + for (int mi = 0; mi < size<0>(o_fp32); ++mi) { + float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0); + #pragma unroll + for (int ni = 1; ni < size<1>(o_fp32); ni++) { + dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni); + } + flash::SumOp sum_op; + dP_sum(mi) = flash::Allreduce::run(dP_sum_cur, sum_op); + } + + Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_dPsum, params.stride_dPsum)(_, bidh, !is_varlen ? bidb : 0); + Tensor gdPsum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mdPsum), Shape>{}, make_coord(m_block)); + if (get<1>(tOcO(_0{}, _0{}, _0{})) == 0) { + #pragma unroll + for (int mi = 0; mi < size(dP_sum); ++mi) { + int const row = get<0>(tOcO(_0{}, mi, _0{})); + gdPsum(row) = row < seqlen_o - m_block * kBlockM ? dP_sum(mi) : 0; + } + } + + int const seqlen_rounded = cute::round_up(seqlen_o, kBlockM); + Tensor mLSElog2 = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_dPsum, params.stride_LSE_log2)(_, bidh, !is_varlen ? bidb : 0); + Tensor gLSElog2 = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mLSElog2), Shape>{}, make_coord(m_block)); + if (thread_idx < seqlen_rounded - m_block * kBlockM && thread_idx < kBlockM) { + gLSElog2(thread_idx) = lse == -INFINITY ? 0.f : lse * float(M_LOG2E); + } + + if constexpr (Clear_dQaccum) { + Tensor mdQaccum = make_tensor(make_gmem_ptr(params.ptr_dQaccum), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0); + Tensor gdQaccum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(m_block)); + GmemTiledCopyAccum gmem_tiled_copy_dQaccum; + auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(thread_idx); + Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum); + Tensor zero = make_fragment_like(tdQgdQaccum); + clear(zero); + cute::copy(Copy_Atom, ElementAccum>{}, zero, tdQgdQaccum); + } + + if (params.dq_semaphore != nullptr && thread_idx == 0) { + int const num_batch = params.num_batch; + int const num_head = get<2>(params.shape_O); + params.dq_semaphore[bidh + bidb * num_head + m_block * num_head * num_batch] = 0; + } + + } + +}; + +} // namespace flash diff --git a/flash-attn/flash_fwd_combine.cu b/flash-attn/flash_fwd_combine.cu new file mode 100644 index 0000000000000000000000000000000000000000..3e85a0a212cb831a62363dac16f5016f27271934 --- /dev/null +++ b/flash-attn/flash_fwd_combine.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
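// The `template void run_mha_fwd_combine_...` lines below are explicit template instantiations: the
// heavy templated launch code in flash_fwd_combine_launch_template.h gets compiled once per listed
// configuration in this translation unit instead of in every includer, which is what speeds up the
// build. The pattern, reduced to a toy example (names here are illustrative, not part of this codebase):
//
//   // heavy_template.h
//   template <typename T> void heavy(T x) { /* expensive-to-compile body */ }
//   // heavy_float.cu -- pays the compile cost exactly once for T = float
//   template void heavy<float>(float);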
+ +#include "flash_fwd_combine_launch_template.h" + +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); + +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); + +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); diff --git a/flash-attn/flash_fwd_combine_kernel.h b/flash-attn/flash_fwd_combine_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..a22e05969d9a964682a91b082faafc4e60b27bdd --- /dev/null +++ b/flash-attn/flash_fwd_combine_kernel.h @@ -0,0 +1,482 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include +#include +#include +#include +#include + +#include "cutlass/arch/grid_dependency_control.h" + +#include "seqlen.h" +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +class FlashAttnFwdCombine { + +public: + + // Type Aliases + using TileShape_MK = TileShape_MK_; + using ArchTag = ArchTag_; + static constexpr int kMaxSplits = 1 << kLogMaxSplits_; + static constexpr int AlignmentLSE = std::min(AlignmentLSE_, int(128 / 8 / sizeof(float))); + static_assert(AlignmentLSE >= 1); + static constexpr int kStages = 4; + + static_assert(ArchTag::kMinComputeCapability >= 75); + static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80; + + static constexpr uint32_t MaxThreadsPerBlock = kNThreads; + static constexpr uint32_t MinBlocksPerMultiprocessor = 2; + + static constexpr int kBlockM = get<0>(TileShape_MK{}); + static constexpr int kBlockK = get<1>(TileShape_MK{}); + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(ElementPartial); + static_assert(kBlockK % kGmemElemsPerLoad == 0, "kBlockK must be a multiple of kGmemElemsPerLoad"); + static constexpr int kBlockKGmem = kBlockK % 128 == 0 ? 128 : (kBlockK % 64 == 0 ? 
64 : 32); + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; + static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow"); + using GmemCopyAtom = std::conditional_t< + Has_cp_async, + cute::Copy_Atom, ElementPartial>, + cute::Copy_Atom, ElementPartial> + >; + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0); + using GmemTiledCopyAccum = decltype( + make_tiled_copy(GmemCopyAtom{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 4 vals per load + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 4 vals per load + + using AlignmentTypeLSE = cute::uint_byte_t(sizeof(float)) * AlignmentLSE>; + static constexpr int kGmemElemsPerLoadLSE = sizeof(AlignmentTypeLSE) / sizeof(float); + static_assert(kBlockM % kGmemElemsPerLoadLSE == 0, "kBlockM must be a multiple of kGmemElemsPerLoadLSE"); + static_assert(kBlockM % 8 == 0, "kBlockM must be a multiple of 8"); + static constexpr int kBlockMSmem = kBlockM % 128 == 0 ? 128 : (kBlockM % 64 == 0 ? 64 : (kBlockM % 32 == 0 ? 32 : (kBlockM % 16 == 0 ? 16 : 8))); + static constexpr int kGmemThreadsPerRowLSE = kBlockMSmem / kGmemElemsPerLoadLSE; + static_assert(MaxThreadsPerBlock % kGmemThreadsPerRowLSE == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRowLSE"); + using GmemLayoutAtomLSE = Layout, Int>, + Stride, _1>>; + static_assert(kMaxSplits % CUTE_STATIC_V(shape<0>(GmemLayoutAtomLSE{})) == 0); + using GmemCopyAtomLSE = std::conditional_t< + Has_cp_async, + cute::Copy_Atom, float>, + cute::Copy_Atom, float> + >; + using GmemTiledCopyLSE = decltype( + make_tiled_copy(GmemCopyAtomLSE{}, + GmemLayoutAtomLSE{}, + Layout>>{})); // Val layout, 4 vals per load + + // Otherwise we get IMA when some threads access sLSE, as we're not doing any masking + static_assert((kBlockM * kMaxSplits * AlignmentLSE) % kNThreads == 0, "kNThreads must divide kBlockM * kMaxSplits * AlignmentLSE"); + // This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts + using SmemLSESwizzle = std::conditional_t< + kBlockMSmem == 8, + Swizzle<5, 0, 5>, + std::conditional_t, Swizzle<3, 2, 3>> + >; + using SmemLayoutAtomLSE = + decltype(composition(SmemLSESwizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutLSE = decltype(tile_to_shape(SmemLayoutAtomLSE{}, Shape, Int>{})); + + using SmemLayoutO = Layout, Int, Int>, + Stride, _1, Int>>; + + // We want each column (kMaxSplits) to be processed by threads in the same warp. + // To reduce the number of shuffles, we want as few threads on the same column as possible. + // E.g., if kBlockM is divisible by 64, and there are 256 threads, we want 4 threads (0, 1, 2, 4) per column + // have have 64 such quads. 
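+    // Editorial sketch (illustrative, unused): with 256 threads and kBlockMSmem == 64, this
+    // layout assigns kSmemThreadsPerColLSEt == 4 threads to each LSE column, so a column-wide
+    // reduction needs only two butterfly shuffles within a warp. The helper below shows that
+    // flash::Allreduce pattern in isolation; its name and template parameter are ours, not the
+    // kernel's.
+    template <int kThreadsPerCol>
+    static CUTLASS_DEVICE float column_allreduce_sum_example(float val) {
+        // Butterfly reduction over the kThreadsPerCol lanes that share one column; every
+        // participating lane ends up holding the full column sum.
+        #pragma unroll
+        for (int offset = kThreadsPerCol / 2; offset > 0; offset /= 2) {
+            val += __shfl_xor_sync(0xffffffffu, val, offset);
+        }
+        return val;
+    }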
+ static_assert(MaxThreadsPerBlock % kBlockMSmem == 0, "MaxThreadsPerBlock must be a multiple of kBlockMSmem"); + static constexpr int kSmemThreadsPerColLSEt = MaxThreadsPerBlock / kBlockMSmem; + static_assert(cutlass::NumThreadsPerWarp % kSmemThreadsPerColLSEt == 0, "kSmemThreadsPerColLSEt must divide NumThreadsPerWarp"); + using S2RLayoutAtomLSE = Layout, Int>>; + using S2RTiledCopyLSE = decltype(make_tiled_copy(cute::Copy_Atom{}, S2RLayoutAtomLSE{}, Layout<_1>{})); + + using ShapeOPartial = cute::Shape; // (seqlen, d, num_splits, head, batch) + using StrideOPartial = cute::Stride; + using ShapeLSEPartial = cute::Shape; // (seqlen, num_splits, head, batch) + using StrideLSEPartial = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen, num_splits, head, batch) + using ShapeO = cute::Shape; // (seqlen, d, head, batch) + using StrideO = cute::Stride; + using ShapeLSE = cute::Shape; // (seqlen, head, batch) + using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch) + + struct SharedStorage : cute::aligned_struct<128> { + cute::array_aligned> smem_lse_partial; + cute::array_aligned smem_max_valid_split; + cute::array_aligned> smem_o_partial; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + ElementPartial const* const ptr_O_partial; + ShapeOPartial const shape_O_partial; + StrideOPartial const stride_O_partial; + float const* const ptr_LSE_partial; + ShapeLSEPartial const shape_LSE_partial; + StrideLSEPartial const stride_LSE_partial; + Element* const ptr_O; + StrideO const stride_O; + float* const ptr_LSE; + StrideLSE const stride_LSE; + int const* const cu_seqlens = nullptr; + int const* const seqused = nullptr; + int const* const num_splits_dynamic_ptr = nullptr; + int* const semaphore_to_reset = nullptr; + }; + + // Kernel entry point API + struct Params { + ElementPartial const* const ptr_O_partial; + ShapeOPartial const shape_O_partial; + StrideOPartial const stride_O_partial; + float const* const ptr_LSE_partial; + ShapeLSEPartial const shape_LSE_partial; + StrideLSEPartial const stride_LSE_partial; + Element* const ptr_O; + StrideO const stride_O; + float* const ptr_LSE; + StrideLSE const stride_LSE; + cutlass::FastDivmod seqlen_divmod, head_divmod; + int const* const cu_seqlens = nullptr; + int const* const seqused = nullptr; + int const* const num_splits_dynamic_ptr = nullptr; + int* const semaphore_to_reset = nullptr; + }; + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. 
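+    // Editorial sketch (illustrative, unused): the two cutlass::FastDivmod objects built by
+    // to_underlying_arguments() below let the kernel split a flattened row index
+    // idx in [0, seqlen * num_heads) into (row-in-sequence, head) without a hardware integer
+    // division; `bidh = seqlen_divmod.divmod(m_idx, idx)` used later is equivalent to:
+    static CUTLASS_DEVICE void flat_idx_to_row_head_example(int idx, int seqlen, int& m_idx, int& bidh) {
+        bidh  = idx / seqlen;   // head index  == quotient returned by FastDivmod::divmod
+        m_idx = idx % seqlen;   // row in seq  == remainder written to the out-parameter
+    }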
+ static + Params + to_underlying_arguments(Arguments const& args) { + assert(get<1>(args.shape_LSE_partial) <= kMaxSplits); + return { + args.ptr_O_partial, + args.shape_O_partial, + args.stride_O_partial, + args.ptr_LSE_partial, + args.shape_LSE_partial, + args.stride_LSE_partial, + args.ptr_O, + args.stride_O, + args.ptr_LSE, + args.stride_LSE, + cutlass::FastDivmod(get<0>(args.shape_LSE_partial)), cutlass::FastDivmod(get<2>(args.shape_LSE_partial)), + args.cu_seqlens, + args.seqused, + args.num_splits_dynamic_ptr, + args.semaphore_to_reset + }; + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse_partial.data()), SmemLayoutLSE{}); + Tensor sMaxValidSplit = make_tensor(make_smem_ptr(shared_storage.smem_max_valid_split.data()), Shape>{}); + Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{}); + + int const thread_idx = threadIdx.x; + int const m_block = blockIdx.x; + int const k_block = blockIdx.y; + int const batch = blockIdx.z; + int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial); + + if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) { + cutlass::arch::wait_on_dependent_grids(); + *params.semaphore_to_reset = 0; + } + if (num_splits <= 1) { return; } + flash::SeqlenInfo seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused}; + int const offset = seqlen_info.offset; + int const seqlen = seqlen_info.seqlen; + int max_idx = seqlen * get<2>(params.shape_LSE_partial); + if constexpr (Varlen) { + if (m_block * kBlockM >= max_idx) { return; } + } + + cutlass::FastDivmod seqlen_divmod_dynamic(seqlen); + + // Step 1: load LSE_partial from gmem -> smem + Tensor mLSEpartial = make_tensor(make_gmem_ptr(params.ptr_LSE_partial + offset * get<0>(params.stride_LSE_partial)), + select<1, 0, 2, 3>(params.shape_LSE_partial), + select<1, 0, 2, 3>(params.stride_LSE_partial))(_, _, _, !Varlen ? 
batch : 0); // (num_splits, seqlen, head) + Tensor mLSEpartial_copy = cute::tiled_divide(mLSEpartial, Shape<_1, Int>{}); + GmemTiledCopyLSE gmem_tiled_copy_LSE; + auto gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_thread_slice(thread_idx); + Tensor tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE); + + // Construct identity layout for sLSE + Tensor cLSE = make_identity_tensor(make_shape(size<0>(sLSE), size<1>(sLSE))); // (NUM_SPLITS, BLK_M) -> (num_splits, blk_m) + // Repeat the partitioning with identity layouts + Tensor tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE); + + cutlass::arch::wait_on_dependent_grids(); + + #pragma unroll + for (int m = 0; m < size<2>(tLSEcLSE); ++m) { + int mi = int(get<1>(tLSEcLSE(_0{}, _0{}, m))); + int idx = m_block * kBlockM + mi; + if (idx < max_idx) { + int m_idx, bidh; + if constexpr (!Varlen) { + bidh = params.seqlen_divmod.divmod(m_idx, idx); + } else { + bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); + } + Tensor mLSEpartial_cur_copy = mLSEpartial_copy(_, _, m_idx, bidh); + #pragma unroll + for (int s = 0; s < size<1>(tLSEcLSE); ++s) { + int si = get<0>(tLSEcLSE(_0{}, s, _0{})); + // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast(&(tLSEsLSE(_0{}, s, m))), reinterpret_cast(&(tLSEsLSE(_0{}, s, m))) / 4 % 32);} + if (si < num_splits) { + cute::copy(gmem_tiled_copy_LSE, mLSEpartial_cur_copy(_, si), tLSEsLSE(_, s, m)); + } else { + cute::fill(tLSEsLSE(_, s, m), -INFINITY); + } + } + } else { + // We don't need to zero out the rest of the LSEs, as we will not write the output to gmem + // cute::fill(tLSEsLSE(_, _, m), -INFINITY); + } + } + if constexpr (Has_cp_async) { cute::cp_async_fence(); } + + // Step 2: Load O_partial from gmem -> smem for split = 0, 1, ..., kStages - 2. + // We want these async loads to be in flight as we compute the LSE. + GmemTiledCopyAccum gmem_tiled_copy_O_partial; + auto gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_thread_slice(thread_idx); + // Construct identity layout for gO + Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O_partial.partition_D(cO); + Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)), + params.shape_O_partial, params.stride_O_partial)(_, _, _, _, !Varlen ? 
batch : 0); // (seqlen, d, num_splits, head) + + // Precompute these values to avoid recomputing them in the loop + Tensor tOmidx = make_tensor(make_shape(size<1>(tOcO))); + Tensor tObidh = make_tensor(make_shape(size<1>(tOcO))); + Tensor tOrOptr = make_tensor(make_shape(size<1>(tOcO))); + #pragma unroll + for (int m = 0; m < size<1>(tOcO); ++m) { + int mi = get<0>(tOcO(_0{}, m, _0{})); + int idx = m_block * kBlockM + mi; + if constexpr (!Varlen) { + tObidh(m) = params.seqlen_divmod.divmod(tOmidx(m), idx); + } else { + tObidh[m] = seqlen_divmod_dynamic.divmod(tOmidx(m), idx); + } + tOrOptr[m] = &mOpartial(tOmidx(m), k_block * kBlockK, _0{}, tObidh(m)); + if (idx >= max_idx) { + tObidh[m] = -1; + } + } + + Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); + if constexpr (!(Is_even_K)) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial) - k_block * kBlockK; } + } + + Tensor tOsOpartial = gmem_thr_copy_O_partial.partition_D(sO); + + auto load_O_partial = [&] (int split, int stage) { + Tensor tOsOpartial_cur = tOsOpartial(_, _, _, stage); + #pragma unroll + for (int m = 0; m < size<1>(tOcO); ++m) { + if (tObidh(m) >= 0) { + Tensor mOpartial_cur = make_tensor(make_gmem_ptr(tOrOptr[m]), mOpartial(_0{}, _, _, _0{}).layout()); + Tensor mOpartial_cur_copy = cute::tiled_divide(mOpartial_cur, Shape>{}); + #pragma unroll + for (int k = 0; k < size<2>(tOcO); ++k) { + int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad; + if (Is_even_K || tOpO(k)) { + cute::copy(gmem_tiled_copy_O_partial, mOpartial_cur_copy(_, k_idx, split), tOsOpartial_cur(_, m, k)); + } + } + } + } + }; + + for (int s = 0; s < kStages - 1; ++s) { + if (s < num_splits) { load_O_partial(s, s); } + if constexpr (Has_cp_async) { cute::cp_async_fence(); } + } + + // Step 3: load and transpose LSE_partial from smem -> rmem + if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait(); } + __syncthreads(); + + S2RTiledCopyLSE s2r_tiled_copy_LSE; + auto s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_thread_slice(thread_idx); + Tensor ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE); + Tensor ts2rrLSE = make_fragment_like(ts2rsLSE); + cute::copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE); + + // Step 4: compute the final LSE along the split dimension + Tensor lse_sum = make_tensor(make_shape(size<2>(ts2rrLSE))); + Tensor ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE); + // We compute the max valid split for each row to short-circuit the computation later + Tensor max_valid_split = make_tensor(make_shape(size<2>(ts2rrLSE))); + static_assert(CUTE_STATIC_V(size<0>(ts2rrLSE)) == 1); + #pragma unroll + for (int m = 0; m < size<2>(ts2rrLSE); ++m) { + float lse_max = ts2rrLSE(_0{}, _0{}, m); + #pragma unroll + for (int s = 1; s < size<1>(ts2rrLSE); ++s) { lse_max = max(lse_max, ts2rrLSE(_0{}, s, m)); } + MaxOp max_op; + lse_max = Allreduce::run(lse_max, max_op); + int max_valid_idx = -1; + #pragma unroll + for (int s = 0; s < size<1>(ts2rrLSE); ++s) { + if (ts2rrLSE(_0{}, s, m) != -INFINITY) { max_valid_idx = get<0>(ts2rcLSE(_0{}, s, _0{})); } + } + MaxOp max_int_op; + max_valid_split[m] = Allreduce::run(max_valid_idx, max_int_op); + float lse_max_cur = lse_max == -INFINITY ? 
0.0f : lse_max; // In case all local LSEs are -inf + float lse_sum_cur = 0.f; + #pragma unroll + for (int s = 0; s < size<1>(ts2rrLSE); ++s) { + float scale = expf(ts2rrLSE(_0{}, s, m) - lse_max_cur); + lse_sum_cur += scale; + // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast(&(ts2rsLSE(_0{}, s, m))), reinterpret_cast(&(ts2rsLSE(_0{}, s, m))) / 4 % 32);} + // ts2rsLSE(_0{}, m, s) = scale; + ts2rrLSE(_0{}, s, m) = scale; + } + SumOp sum_op; + lse_sum_cur = Allreduce::run(lse_sum_cur, sum_op); + lse_sum(m) = logf(lse_sum_cur) + lse_max; + float inv_sum = (lse_sum_cur == 0.f || lse_sum_cur != lse_sum_cur) ? 0.f : 1.f / lse_sum_cur; + #pragma unroll + for (int s = 0; s < size<1>(ts2rrLSE); ++s) { ts2rrLSE(_0{}, s, m) *= inv_sum; } + } + // Store the scales exp(lse - lse_logsum) back to smem + cute::copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE); + + // Store max_valid_split to smem + #pragma unroll + for (int m = 0; m < size<2>(ts2rrLSE); ++m) { + if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to smem + int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m))); + if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; } + } + } + + // Step 5: store final LSE back to gmem + if (k_block == 0) { + auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial); + Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE)(_, _, !Varlen ? batch : 0); + #pragma unroll + for (int m = 0; m < size<2>(ts2rrLSE); ++m) { + if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to gmem + int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m))); + int idx = m_block * kBlockM + mi; + if (idx < max_idx) { + int m_idx, bidh; + if constexpr (!Varlen) { + bidh = params.seqlen_divmod.divmod(m_idx, idx); + } else { + bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); + } + // printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m)); + mLSE(m_idx, bidh) = lse_sum(m); + } + } + } + } + + // Step 6: read O_partial from gmem -> smem -> rmem and accumulate the final O + __syncthreads(); + int thr_max_valid_split = sMaxValidSplit[get<0>(tOcO(_0{}, _0{}, _0{}))]; + #pragma unroll + for (int m = 1; m < size<1>(tOcO); ++m) { thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[get<0>(tOcO(_0{}, m, _0{}))]); } + Layout tOrOpartial_layout = gmem_thr_copy_O_partial.partition_S(make_tensor(TileShape_MK{})).layout(); + Tensor tOrOpartial = make_fragment_like(tOrOpartial_layout); + Tensor tOrO = make_fragment_like(tOrOpartial); + clear(tOrO); + int stage_load = kStages - 1, stage_compute = 0; + #pragma unroll 4 // Already tuned for speed + for (int s = 0; s <= thr_max_valid_split; ++s) { + Tensor scale = make_tensor(make_shape(size<1>(tOrOpartial))); + #pragma unroll + for (int m = 0; m < size<1>(tOrOpartial); ++m) { scale(m) = sLSE(s, get<0>(tOcO(_0{}, m, _0{}))); } + + if (s + kStages - 1 <= thr_max_valid_split) { load_O_partial(s + kStages - 1, stage_load); } + if constexpr (Has_cp_async) { cute::cp_async_fence(); } + stage_load = stage_load < kStages - 1 ? 
stage_load + 1 : 0; + if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait(); } + // We don't need __syncthreads() because each thread is just reading its own data from smem + cute::copy(Copy_Atom, ElementPartial>{}, + tOsOpartial(_, _, _, stage_compute), tOrOpartial); + stage_compute = stage_compute < kStages - 1 ? stage_compute + 1 : 0; + + #pragma unroll + for (int m = 0; m < size<1>(tOrOpartial); ++m) { + if (tObidh(m) >= 0 && scale(m) > 0.f) { + #pragma unroll + for (int k = 0; k < size<2>(tOrOpartial); ++k) { + if (Is_even_K || tOpO(k)) { + Tensor rOpartial = make_tensor_like(tOrOpartial(_, m, k)); + flash::convert_type_out(tOrOpartial(_, m, k), rOpartial); + #pragma unroll + for (int i = 0; i < size<0>(tOrOpartial); ++i) { + tOrO(i, m, k) += scale(m) * rOpartial[i]; + } + } + } + } + } + } + + // Step 7: Write the final O to gmem + Tensor rO = make_tensor_like(tOrO); + flash::convert_type_out(tOrO, rO); + auto shape_O = make_shape(get<0>(params.shape_O_partial), get<1>(params.shape_O_partial) - k_block * kBlockK, get<3>(params.shape_O_partial), get<4>(params.shape_O_partial)); + Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O) + k_block * kBlockK * get<1>(params.stride_O)), + shape_O, params.stride_O)(_, _, _, !Varlen ? batch : 0); + Tensor mO_copy = cute::tiled_divide(mO, Shape<_1, Int>{}); + GmemTiledCopy gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + + #pragma unroll + for (int m = 0; m < size<1>(tOcO); ++m) { + if (tObidh(m) >= 0) { + #pragma unroll + for (int k = 0; k < size<2>(tOcO); ++k) { + int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad; + if (Is_even_K || tOpO(k)) { + cute::copy(gmem_tiled_copy_O, rO(_, m, k), mO_copy(_, tOmidx(m), k_idx, tObidh(m))); + } + } + } + } + + } + +}; + +} // namespace flash diff --git a/flash-attn/flash_fwd_combine_launch_template.h b/flash-attn/flash_fwd_combine_launch_template.h new file mode 100644 index 0000000000000000000000000000000000000000..11d422924b40ee82bfe2b9a2fd8d5d7ab57f0a78 --- /dev/null +++ b/flash-attn/flash_fwd_combine_launch_template.h @@ -0,0 +1,80 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/arch/arch.h" // For cutlass::arch::Sm80 +#include "cutlass/device_kernel.h" // For device_kernel +#include "cutlass/kernel_launch.h" // For kernel_launch + +#include "static_switch.h" +#include "flash.h" +#include "flash_fwd_combine_kernel.h" + +using namespace cute; + +template +void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl) { + using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; + using TileShape_MK = cute::Shape, Int>; + using CombineKernel = flash::FlashAttnFwdCombine; + + typename CombineKernel::Arguments args { + static_cast(params.oaccum_ptr), + {!Varlen ? params.seqlen_q : params.total_q, params.dv, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_O_partial + {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0}, // stride_O_partial + static_cast(params.softmax_lseaccum_ptr), + {!Varlen ? 
params.seqlen_q : params.total_q, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_LSE_partial + {_1{}, params.lseaccum_split_stride, params.lseaccum_head_stride, !Varlen ? params.lseaccum_batch_stride : 0}, // stride_LSE_partial + static_cast(params.o_ptr), + {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0}, // stride_O + static_cast(params.softmax_lse_ptr), + {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0}, // stride_LSE + params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore + }; + + typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args); + int num_blocks_k = cute::ceil_div(params.dv, kBlockK); + int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h, kBlockM); + dim3 grid_m(num_blocks_m, num_blocks_k, params.b); + auto kernel = cutlass::device_kernel; + int smem_size = CombineKernel::SharedStorageSize; + if (smem_size >= 48 * 1024) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // kernel<<>>(kernel_params); + cutlass::kernel_launch(grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream, kernel_params, Arch >= 90 && enable_pdl /*launch_with_pdl*/); + CHECK_CUDA_KERNEL_LAUNCH(); +} + +template +void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl) { + // We want kBlockM to be as small as possible to maximize parallelism. + // E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats). + static_assert(kBlockK % 32 == 0, "kBlockK must be a multiple of 32"); + static constexpr int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32); + ARCH_SWITCH(params.arch, Arch, [&] { + BOOL_SWITCH(params.cu_seqlens_q || params.seqused_q, Varlen, [&] { + if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32. + if (params.num_splits <= 16) { + run_flash_fwd_combine(params, stream, enable_pdl); + return; + } + } + if (params.num_splits <= 32) { + run_flash_fwd_combine(params, stream, enable_pdl); + } else if (params.num_splits <= 64) { + run_flash_fwd_combine(params, stream, enable_pdl); + } else if (params.num_splits <= 128) { + run_flash_fwd_combine(params, stream, enable_pdl); + } else { + run_flash_fwd_combine(params, stream, enable_pdl); + } + }); + }); +} diff --git a/flash-attn/flash_fwd_kernel_sm80.h b/flash-attn/flash_fwd_kernel_sm80.h new file mode 100644 index 0000000000000000000000000000000000000000..b308d2d1b8892397db52434df9da4d4ab6cf5c56 --- /dev/null +++ b/flash-attn/flash_fwd_kernel_sm80.h @@ -0,0 +1,215 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include +#include +#include +#include + +#include "seqlen.h" +#include "utils.h" +#include "softmax.h" + +namespace flash { + +using namespace cute; + +template +class FlashAttnFwdSm80 { + +public: + + // Type Aliases + using CollectiveMainloop = CollectiveMainloop_; + using CollectiveEpilogue = CollectiveEpilogue_; + static constexpr bool Is_causal = CollectiveMainloop::Is_causal; + static constexpr bool Is_local = CollectiveMainloop::Is_local; + static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen); + static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap; + static constexpr bool Varlen = CollectiveMainloop::Varlen; + static constexpr bool PagedKV = CollectiveMainloop::PagedKV; + static constexpr bool Split = CollectiveMainloop::Split; + static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8; + static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V; + static constexpr bool AppendKV = CollectiveMainloop::AppendKV; + static constexpr bool PackGQA = CollectiveMainloop::PackGQA; + static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads; + using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t; + + // Mainloop derived types + using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + // Epilogue derived types + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(ArchTag::kMinComputeCapability >= 80); + + using TileScheduler = TileScheduler_; + using TileSchedulerArguments = typename flash::TileSchedulerArguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMma{})); + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})); + static constexpr uint32_t MinBlocksPerMultiprocessor = NumThreads == 128 ? 2 : 1; + + // Kernel level shared memory storage + // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v + smem_k and not smem_q + // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v) + sizeof(smem_k). + static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) + - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v))) + - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k))); + static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 
0 : mainloop_smem_padding_; + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128> { + union { + struct { + cute::array padding_; + typename CollectiveMainloop::TensorStorage mainloop; + }; + // We want smem_o to line up with the start of smem_v + typename CollectiveEpilogue::TensorStorage epilogue; + }; + } tensors; + + alignas(16) typename TileScheduler::SharedStorage smem_scheduler; + + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + cutlass::KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + cutlass::KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + return { + CollectiveMainloop::to_underlying_arguments(args.mainloop), + CollectiveEpilogue::to_underlying_arguments(args.epilogue), + hw_info, + TileScheduler::to_underlying_arguments(args.scheduler) + }; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count * MinBlocksPerMultiprocessor); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue; + + TileScheduler scheduler(reinterpret_cast(&shared_storage.smem_scheduler)); + // Initialize matmul objects. + TiledMma tiled_mma; + + scheduler.init_consumer(); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + CUTLASS_PRAGMA_NO_UNROLL + for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work(params.scheduler) : scheduler.template get_initial_work(params.scheduler); + work_tile_info.is_valid(params.scheduler); + work_tile_info = warp_idx == 0 ? scheduler.template get_next_work(params.scheduler, work_tile_info) : scheduler.template get_next_work(params.scheduler, work_tile_info)) { + // Attention output (GEMM-II) accumulator. + Tensor tOrO = partition_fragment_C(tiled_mma, select<0, 2>(TileShape_MNK{})); + float softmax_scale_log2 = params.mainloop.softmax_scale_log2; + // If there's tanh softcap, the scaling will be done before tanh. 
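+            // Editorial note (illustrative only, kept as a comment): the scale is carried in
+            // log2 form because the softmax is evaluated with exp2f, using the identity
+            //     expf(score * softmax_scale) == exp2f(score * (softmax_scale * M_LOG2E)).
+            // For a score s and running row-max m, the unnormalized probability is therefore
+            //     float p = exp2f((s - m) * softmax_scale_log2);
+            // which is why only softmax_scale_log2 (with the FP8 descales folded in below)
+            // is handed to the softmax object.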
+ auto block_coord = work_tile_info.get_block_coord(params.scheduler); + int const bidb = get<2>(block_coord); + if constexpr (Is_FP8 && !Has_softcap) { + int const bidh = get<1>(block_coord); + int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh; + float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)]; + float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)]; + softmax_scale_log2 *= q_descale * k_descale; + } + flash::Softmax<2 * (2 * kBlockM / NumThreads), /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2); + + SeqlenInfo_t seqlen_info{ + bidb, + get<0>(params.mainloop.shape_Q), + !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), + get<0>(params.mainloop.shape_K_new), + params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, + params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, + params.mainloop.seqlens_rotary + }; + if constexpr (AppendKV) { + bool tile_new_valid = mainloop.store_kv_new( + params.mainloop, threadIdx.x, shared_storage, seqlen_info, block_coord); + if (tile_new_valid) { __syncthreads(); } + } + bool tile_valid = mainloop.mma( + params.mainloop, tOrO, softmax, threadIdx.x, seqlen_info, block_coord, + shared_storage); + scheduler.prefetch_next_work(params.scheduler, work_tile_info); + if (tile_valid) { + // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); } + epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma, + threadIdx.x, block_coord); + } else { + // Write 0 to gO and -inf to gLSE. + epilogue.store_zero(params.epilogue, threadIdx.x, block_coord); + } + } + + } + +}; + +} // namespace flash diff --git a/flash-attn/flash_fwd_kernel_sm90.h b/flash-attn/flash_fwd_kernel_sm90.h new file mode 100644 index 0000000000000000000000000000000000000000..47b3817cd281d772e7366347682e97e4cb4cd861 --- /dev/null +++ b/flash-attn/flash_fwd_kernel_sm90.h @@ -0,0 +1,458 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include +#include +#include +#include +#include +#include +#include "cutlass/pipeline/pipeline.hpp" + +#include "cutlass/arch/grid_dependency_control.h" + +#include "seqlen.h" +#include "utils.h" +#include "softmax.h" + +namespace flash { + +using namespace cute; + +template +class FlashAttnFwdSm90 { + +public: + + // Type Aliases + using CollectiveMainloop = CollectiveMainloop_; + using CollectiveEpilogue = CollectiveEpilogue_; + static constexpr bool Is_causal = CollectiveMainloop::Is_causal; + static constexpr bool Is_local = CollectiveMainloop::Is_local; + static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen); + static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap; + static constexpr bool Varlen = CollectiveMainloop::Varlen; + static constexpr bool Split = CollectiveMainloop::Split; + static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8; + static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V; + static constexpr bool AppendKV = CollectiveMainloop::AppendKV; + static constexpr bool HasQv = CollectiveMainloop::HasQv; + static constexpr bool Use_TMA_Q = CollectiveMainloop::Use_TMA_Q; + static constexpr bool Use_TMA_KV = CollectiveMainloop::Use_TMA_KV; + static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O; + static constexpr bool PackGQA = CollectiveMainloop::PackGQA; + static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads; + static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim; + static constexpr bool LargeHeadDimV = CollectiveMainloop::LargeHeadDimV; + static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV); + using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t; + + // Mainloop derived types + using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV; + using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ClusterShape = typename CollectiveMainloop::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + using BarrierQ = std::conditional_t; + + // Epilogue derived types + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(ArchTag::kMinComputeCapability >= 90); + + using TileScheduler = TileScheduler_; + using TileSchedulerArguments = typename flash::TileSchedulerArguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaPV{})) / cutlass::NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaPV{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); + + /// Register requirement for Load and Math WGs + // If we use cp.async to load K and V, we need more registers for the producer WG. + static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 24 : 40) : 32); + static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 
256 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 240 : 232) : 160); + // If you want to print from the producer warp, you'd need to increase the number of registers + // Otherwise you'll get CUDA error. + // static constexpr uint32_t LoadRegisterRequirement = 40; + // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152; + + // Kernel level shared memory storage + // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v + // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v). + static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v))); + static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_; + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _1> { + union { + struct { + cute::array padding_; + typename CollectiveMainloop::TensorStorage mainloop; + }; + // We want smem_o to line up with the start of smem_v + typename CollectiveEpilogue::TensorStorage epilogue; + }; + } tensors; + struct PipelineStorage : cute::aligned_struct<16, _1> { + alignas(16) BarrierQ barrier_Q; + alignas(16) BarrierQ barrier_Qv; + alignas(16) cutlass::arch::ClusterBarrier barrier_O; + alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k; + alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v; + alignas(16) typename CollectiveMainloop::MainloopPipelineVt::SharedStorage pipeline_vt; + alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_k_new; + alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_v_new; + alignas(16) typename TileScheduler::SharedStorage smem_scheduler; + } pipelines; + + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + cutlass::KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + cutlass::KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. 
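+    // Editorial sketch (illustrative helper, not called anywhere): when hw_info.sm_count <= 0,
+    // to_underlying_arguments() below queries the SM count, and the persistent tile scheduler
+    // then sizes its launch grid from it. The query boils down to cudaDeviceGetAttribute with
+    // cudaDevAttrMultiProcessorCount, wrapped by CUTLASS as:
+    static int query_sm_count_example(int device_id) {
+        return cutlass::KernelHardwareInfo::query_device_multiprocessor_count(device_id);
+    }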
+ static + Params + to_underlying_arguments(Arguments const& args) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + return { + CollectiveMainloop::to_underlying_arguments(args.mainloop), + CollectiveEpilogue::to_underlying_arguments(args.epilogue), + hw_info, + TileScheduler::to_underlying_arguments(args.scheduler) + }; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + + static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; + static constexpr int MmaThreadOffset = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup; + static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); + + using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK; + using MainloopPipelineV = typename CollectiveMainloop::MainloopPipelineV; + using MainloopPipelineVt = typename CollectiveMainloop::MainloopPipelineVt; + using MainloopPipelineKVNew = typename CollectiveMainloop::MainloopPipelineKVNew; + using PipelineState = typename CollectiveMainloop::PipelineState; + using PipelineParamsK = typename MainloopPipelineK::Params; + using PipelineParamsV = typename MainloopPipelineV::Params; + using PipelineParamsVt = typename MainloopPipelineVt::Params; + using PipelineParamsKVNew = typename MainloopPipelineKVNew::Params; + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int const lane_predicate = cute::elect_one_sync(); + int const warp_idx = cutlass::canonical_warp_idx_sync(); + + // Issue Tma Descriptor Prefetch from a single thread + if (warp_idx == 0 && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Obtain warp index + int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + int warp_group_idx = cutlass::canonical_warp_group_idx(); + + if (warp_idx == 0 && lane_predicate) { + shared_storage.pipelines.barrier_Q.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/); + if constexpr (HasQv) { + shared_storage.pipelines.barrier_Qv.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/); + } + shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) * (Use_TMA_O ? 1 : NumMmaThreads) /*numThreads*/); + } + + // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init(); + PipelineParamsK pipeline_params_k; + pipeline_params_k.role = warp_group_idx == 0 + ? 
MainloopPipelineK::ThreadCategory::Producer + : MainloopPipelineK::ThreadCategory::Consumer; + if constexpr (Use_TMA_KV) { + pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; + pipeline_params_k.is_leader = warp_group_thread_idx == 0; + pipeline_params_k.num_consumers = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup; + } else { + pipeline_params_k.consumer_arv_count = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup; + pipeline_params_k.producer_arv_count = NumProducerThreads; + } + + static_assert(is_same_v); + PipelineParamsVt pipeline_params_vt = pipeline_params_k; + if constexpr (Use_TMA_KV && !SameHeadDim) { + pipeline_params_vt.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; + if constexpr (LargeHeadDimV) { pipeline_params_vt.num_consumers = NumMmaThreads; } + } else { + if constexpr (LargeHeadDimV) { pipeline_params_vt.consumer_arv_count = NumMmaThreads; } + } + + MainloopPipelineK pipeline_k = [&] { + if constexpr (Use_TMA_KV) { + return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{}); + } else { + return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k); + } + }(); + // MainloopPipelineV pipeline_v(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{}); + MainloopPipelineV pipeline_v = [&] { + if constexpr (!Transpose_V) { + static_assert(is_same_v); + if constexpr (Use_TMA_KV) { + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt, ClusterShape{}); + } else { + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt); + } + } else { + PipelineParamsV pipeline_params_v; + pipeline_params_v.role = warp_group_idx == 0 + ? MainloopPipelineV::ThreadCategory::Producer + : MainloopPipelineV::ThreadCategory::Consumer; + pipeline_params_v.producer_arv_count = NumProducerThreads; + pipeline_params_v.consumer_arv_count = NumMmaThreads; + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v); + } + }(); + // If we need to transpose V (e.g. FP8 and V is row-major), we use pipeline_vt for the TMA, then + // the producer WG will read from pipeline_vt and write to pipeline_v. + // If we don't need to transpose V, we use pipeline_v for the TMA, and pipeline_vt won't be used. + // Technically for pipeline_params_vt, warp0 of WG0 is the producer and all of WG0 are consumers. + // However, the thread role isn't used in the pipeline implementation. + MainloopPipelineVt pipeline_vt = [&] { + if constexpr (Use_TMA_KV) { + pipeline_params_vt.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG + return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt, ClusterShape{}); + } else { + pipeline_params_vt.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG + return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt); + } + }(); + + PipelineParamsKVNew pipeline_params_kv_new; + pipeline_params_kv_new.role = warp_group_idx == 0 + ? 
MainloopPipelineKVNew::ThreadCategory::Producer + : MainloopPipelineKVNew::ThreadCategory::Consumer; + pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; + pipeline_params_kv_new.is_leader = warp_group_thread_idx == 0; + pipeline_params_kv_new.num_consumers = NumMmaThreads; + auto pipeline_k_new = cute::conditional_return(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_k_new, pipeline_params_kv_new, ClusterShape{}), nullptr); + if constexpr (!SameHeadDim) { + pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; + } + auto pipeline_v_new = cute::conditional_return(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_v_new, pipeline_params_kv_new, ClusterShape{}), nullptr); + + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue; + + // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } else { + __syncthreads(); + } + + TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); + + if (warp_group_idx == 0) { // Producer + cutlass::arch::warpgroup_reg_dealloc(); + + // The pipelines for AppendKV and main attention are different, since e.g. main attention + // might use cp.async to load KV (if PagedKVNonTMA) while AppendKV always uses TMA to load + // KV_new. Since the pipeline states are different, we have to manually sync to make + // sure the two pipelines don't race when accessing smem_k and smem_v. + PipelineState smem_pipe_write = cutlass::make_producer_start_state(); + PipelineState smem_pipe_write_new = cutlass::make_producer_start_state(); + int work_idx = 0; + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp; + if constexpr (SingleProducerWarp) { + if (warp_idx_in_warpgroup != 0) { return; } + } + if (!SingleProducerWarp && warp_idx_in_warpgroup != 0) { scheduler.init_consumer(); } + + cutlass::arch::wait_on_dependent_grids(); + + // Load Q, K, V + for (auto work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_initial_work(params.scheduler) : scheduler.template get_initial_work(params.scheduler); + work_tile_info.is_valid(params.scheduler); + work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_next_work(params.scheduler, work_tile_info) : scheduler.template get_next_work(params.scheduler, work_tile_info)) { + + auto block_coord = work_tile_info.get_block_coord(params.scheduler); + SeqlenInfo_t seqlen_info{ + get<2>(block_coord) /*bidb*/, + get<0>(params.mainloop.shape_Q), + !params.mainloop.ptr_pagetable ? 
size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), + get<0>(params.mainloop.shape_K_new), + params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, + params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, + params.mainloop.seqlens_rotary + }; + if constexpr (AppendKV) { + bool tile_new_valid = mainloop.load_kv_new( + params.mainloop, pipeline_k_new, pipeline_v_new, + smem_pipe_write_new, shared_storage, seqlen_info, block_coord, work_idx); + if (tile_new_valid) { + // if (threadIdx.x == 0) { printf("Producer: Before sync\n"); } + cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast(FwdNamedBarriers::AppendKV) /*id*/); + // if (threadIdx.x == 0) { printf("Producer: After sync\n"); } + } + } + auto scheduler_prefetch = [&scheduler, ¶ms, &work_tile_info]() { + scheduler.prefetch_next_work(params.scheduler, work_tile_info); + }; + // pipeline_vt won't be used if we don't need to transpose V. + mainloop.load(params.mainloop, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, + shared_storage, scheduler_prefetch, seqlen_info, block_coord, work_idx); + } + mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, work_idx); + } else { // Consumer + cutlass::arch::warpgroup_reg_alloc(); + + // Initialize matmul objects. + TiledMmaPV tiled_mma_pv; + + PipelineState smem_pipe_read; + PipelineState smem_pipe_read_new; + // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v + // (like in Cutlass's gemm) because the read and release pipeline states are always the same. + + scheduler.init_consumer(); + mainloop.mma_init(); + + int work_idx = 0; + CUTLASS_PRAGMA_NO_UNROLL + for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); + work_tile_info.is_valid(params.scheduler); + // get_next_work will be called before the epilogue + ) { + auto block_coord = work_tile_info.get_block_coord(params.scheduler); + int const bidb = get<2>(block_coord); + SeqlenInfo_t seqlen_info{ + bidb, + get<0>(params.mainloop.shape_Q), + !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), + get<0>(params.mainloop.shape_K_new), + params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, + params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, + params.mainloop.seqlens_rotary + }; + if constexpr (AppendKV) { + bool tile_new_valid = mainloop.store_kv_new( + params.mainloop, pipeline_k_new, pipeline_v_new, smem_pipe_read_new, + threadIdx.x - MmaThreadOffset, shared_storage, seqlen_info, block_coord); + if (tile_new_valid) { + // if (threadIdx.x == 128) { printf("Consumer: Before sync\n"); } + // We need this sync so that the gmem write from the consumers is visible to the producer + // that might do TMA read after that. + asm volatile ("fence.proxy.async.global;"); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast(FwdNamedBarriers::AppendKV) /*id*/); + // arrive is enough, we don't need sync. The producer will sync, which means + // after that sync we're guaranteed that the AppendKV pipeline have finished + // loading and consumer smem_k and smem_v. + // if (threadIdx.x == 128) { printf("Consumer: After sync\n"); } + } + } + // If there's tanh softcap, the scaling will be done before tanh. 
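+          // Editorial sketch (comment only): for FP8 inputs without softcap, the per-tensor
+          // dequantization factors can be folded straight into the log2-domain softmax scale,
+          // since every score of this (batch, kv-head) needs the same constant applied:
+          //     effective_scale_log2 = softmax_scale_log2      // assumed == softmax_scale * M_LOG2E
+          //                          * q_descale * k_descale;  // FP8 dequantization factors
+          // The next few lines compute exactly this product before constructing the softmax.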
+ float softmax_scale_log2 = params.mainloop.softmax_scale_log2; + if constexpr (Is_FP8 && !Has_softcap) { + int const bidh = get<1>(block_coord); + int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh; + float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)]; + float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)]; + softmax_scale_log2 *= q_descale * k_descale; + } + flash::Softmax softmax(softmax_scale_log2); + // Attention output (GEMM-II) accumulator. + Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{})); + bool tile_valid; + if constexpr (!LargeHeadDimV) { + tile_valid = mainloop.mma( + params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, + tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); + } else { // mma_pv might not compile if !LargeHeadDimV + if (warp_group_idx == 1) { + tile_valid = mainloop.mma( + params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, + tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); + } else { + tile_valid = mainloop.mma_pv( + params.mainloop, pipeline_v, smem_pipe_read, + tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage); + } + } + // Do this here before the epilogue so that the next tile is ready to go. + work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info); + if constexpr (Split && Varlen) { + if (!work_tile_info.is_valid(params.scheduler)) { // Last tile + cutlass::arch::launch_dependent_grids(); + } + } + if (tile_valid) { + // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); } + epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv, + threadIdx.x - MmaThreadOffset, block_coord); + } else { + // Write 0 to gO and -inf to gLSE. + epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord); + } + } + epilogue.store_tail(); + } + + } + +}; + +} // namespace flash diff --git a/flash-attn/flash_fwd_launch_template.h b/flash-attn/flash_fwd_launch_template.h new file mode 100644 index 0000000000000000000000000000000000000000..b8af2977f116bb1ff6ad5bb2d354fb02a6a4ba3b --- /dev/null +++ b/flash-attn/flash_fwd_launch_template.h @@ -0,0 +1,223 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" // For device_kernel +#include +#include "cutlass/cluster_launch.hpp" +#include "cutlass/kernel_launch.h" + +#include "static_switch.h" +#include "flash.h" +#include "tile_size.h" +#include "tile_scheduler.hpp" +#include "flash_fwd_kernel_sm90.h" +#include "flash_fwd_kernel_sm80.h" +#include "mainloop_fwd_sm90_tma_gmma_ws.hpp" +#include "mainloop_fwd_sm80.hpp" +#include "epilogue_fwd.hpp" + +using namespace cute; + +template +void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time"); + static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time"); + static_assert(!(AppendKV && !Varlen), "AppendKV requires Varlen"); + static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v; + static constexpr bool FP8_TransposeV = Is_FP8 && !V_colmajor; + using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; + + // Can't use structured binding since it's not compatible with constexpr + static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap); + static constexpr std::tuple kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV); + static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS); + static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS); + static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap); + static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap); + static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS); + static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS); + static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS); + + using TileShape_MNK = cute::Shape, Int, Int>; + using TileShape_MNK_PV = cute::Shape, Int, Int>; + using ClusterShape = cute::Shape, _1, _1>; + using CollectiveMainloop = std::conditional_t< + Arch >= 90, + flash::CollectiveMainloopFwdSm90, + flash::CollectiveMainloopFwdSm80 + >; + using CollectiveEpilogue = flash::CollectiveEpilogueFwd; + + static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; + using SchedulerPersistent = std::conditional_t= 90 /*WarpSpecialized*/>, + std::conditional_t, + flash::DynamicPersistentTileScheduler= 90 /*WarpSpecialized*/> + > + >; + using SchedulerSingleTile = flash::SingleTileScheduler; + // If Split then we probably don't have enough work for PersistentScheduler to be useful. + // However, if Varlen (e.g., during decode where we have max_seqlens), using PersistentScheduler is better + // since we'll avoid launching a bunch of thread blocks that immediately exit. + // On Sm80, noncausal persistent seems a bit slower. + static constexpr bool UsePersistentScheduler = Arch >= 90 ? 
!(Split && !Varlen) : ((Is_causal && !Varlen) || (Varlen && Split)); + using Scheduler = std::conditional_t; + using AttnKernel = std::conditional_t< + Arch >= 90, + flash::enable_sm90_or_later>, + flash::enable_sm80_to_sm89> + >; + + bool const is_varlen_q = params.cu_seqlens_q; + bool const is_varlen_k = params.cu_seqlens_k; + bool const is_varlen_k_new = params.cu_seqlens_knew; + int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q; + int batch_q = !is_varlen_q ? params.b : 1; + int batch_k = !is_varlen_k ? (params.kv_batch_idx ? params.b_k : params.b) : 1; + typename CollectiveMainloop::StrideV v_strides = + cute::conditional_return( + make_stride(params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0), + make_stride(_1{}, params.v_dim_stride, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0)); + typename CollectiveMainloop::Arguments mainloop_args { + static_cast(params.q_ptr), + {seqlen_q, params.d, params.h, batch_q}, // shape_Q + {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q + static_cast(params.k_ptr), + {!params.page_table ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size, + params.d, params.h_k, !params.page_table ? batch_k : params.num_pages}, // shape_K + {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K + static_cast(params.v_ptr), + params.dv, // headdim_v + v_strides, // stride_V + static_cast(params.knew_ptr), + {!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1}, // shape_K_new + {params.knew_row_stride, _1{}, params.knew_head_stride, !is_varlen_k_new ? params.knew_batch_stride : 0}, // stride_K_new + static_cast(params.vnew_ptr), + {params.vnew_row_stride, _1{}, params.vnew_head_stride, !is_varlen_k_new ? params.vnew_batch_stride : 0}, // stride_V_new + static_cast(params.qv_ptr), + {params.qv_row_stride, _1{}, params.qv_head_stride, !is_varlen_q ? params.qv_batch_stride : 0}, // stride_Qv + static_cast(params.rotary_cos_ptr), + {params.seqlen_k, params.rotary_dim / 2}, // shape_rotary, the seqlen shape doesn't matter + {params.rotary_dim / 2, _1{}}, // stride_rotary_cos + static_cast(params.rotary_sin_ptr), + {params.rotary_dim / 2, _1{}}, // stride_rotary_sin + params.is_rotary_interleaved, + params.page_table, + // if page_size is not set, avoid dividing by zero + {params.kv_batch_idx ? params.b_k : params.b, !params.page_table ? 0 : params.seqlen_k / params.page_size}, // shape_page_table + {params.page_table_batch_stride, _1{}}, // stride_page_table + params.scale_softmax, + params.q_descale_ptr, params.k_descale_ptr, params.v_descale_ptr, + {params.q_descale_batch_stride, params.q_descale_head_stride}, + {params.k_descale_batch_stride, params.k_descale_head_stride}, + {params.v_descale_batch_stride, params.v_descale_head_stride}, + params.window_size_left, params.window_size_right, params.attention_chunk, + params.softcap, + params.num_splits, + params.kv_batch_idx, + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, + params.seqused_q, params.seqused_k, + params.leftpad_k, params.seqlens_rotary + }; + typename CollectiveEpilogue::Arguments epilogue_args { + static_cast(params.o_ptr), + {seqlen_q, params.dv, params.h, batch_q, params.num_splits}, // shape_O + {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? 
params.o_batch_stride : 0, 0}, // stride_O + static_cast(params.oaccum_ptr), + {params.oaccum_row_stride, _1{}, params.oaccum_head_stride, !is_varlen_q ? params.oaccum_batch_stride : 0, params.oaccum_split_stride}, // stride_O_partial + static_cast(params.softmax_lse_ptr), + {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, 0}, // stride_LSE + static_cast(params.softmax_lseaccum_ptr), + {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, params.h * seqlen_q * batch_q}, // stride_LSE_partial + params.h_k, + params.cu_seqlens_q, params.seqused_q + }; + + int qhead_per_khead = !PackGQA ? 1 : cutlass::ceil_div(params.h, params.h_k); + int num_blocks_m = cutlass::ceil_div(params.seqlen_q * qhead_per_khead, get<0>(TileShape_MNK{})); + num_blocks_m = cutlass::round_up(num_blocks_m, size<0>(ClusterShape{})); + typename flash::TileSchedulerArguments scheduler_args { + num_blocks_m, !PackGQA ? params.h : params.h_k, params.b, params.num_splits, + params.h / params.h_k, + params.seqlen_q, + params.seqlen_k, params.d, params.dv, sizeof(Element), + params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q, + // params.num_m_blocks_ptr, + params.num_splits_dynamic_ptr, + }; + + if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) { + prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/); + CHECK_CUDA_KERNEL_LAUNCH(); + } + + int device; + CHECK_CUDA(cudaGetDevice(&device)); + typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({ + mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args + }); + + dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params); + dim3 block_dims = AttnKernel::get_block_shape(); + int smem_size = AttnKernel::SharedStorageSize; + // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q)); + // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k)); + // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)); + // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v); + // Get the ptr to kernel function. 
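+  // The launch below takes one of two paths: with a multi-block cluster shape the kernel goes
+  // through cutlass::launch_kernel_on_cluster (cluster dims taken from ClusterShape); otherwise
+  // cutlass::kernel_launch is used, with programmatic dependent launch (PDL) requested on SM90
+  // varlen runs so this kernel can begin while the scheduler-metadata kernel launched above is
+  // still finishing. In both paths, dynamic shared memory above the 48 KB default must first be
+  // enabled via cudaFuncSetAttribute(cudaFuncAttributeMaxDynamicSharedMemorySize).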
+ if constexpr (size(ClusterShape{}) > 1) { + void const* kernel = (void const*) cutlass::device_kernel; + if (smem_size >= 48 * 1024) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); + cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream}; + cutlass::launch_kernel_on_cluster(launch_params, kernel, kernel_params); + } else { + auto kernel = cutlass::device_kernel; + if (smem_size >= 48 * 1024) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // kernel<<>>(kernel_params); + cutlass::kernel_launch(grid_dims, block_dims, smem_size, stream, kernel_params, + Arch >= 90 && Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation /*launch_with_pdl*/); + } + CHECK_CUDA_KERNEL_LAUNCH(); +} + +template +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported"); + static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v; + using T_out = std::conditional_t; + CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] { + VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] { + static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1; + VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { + // Only needed here to decide if we should use cluster + static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128; + + static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; + BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { + static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256; + APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { + // Only use Cluster if number of tiles along seqlen_q is even and not varlen + CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { + static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; + run_flash_fwd(params, stream); + }); + }); + }); + }); + }); + }); +} diff --git a/flash-attn/flash_prepare_scheduler.cu b/flash-attn/flash_prepare_scheduler.cu new file mode 100644 index 0000000000000000000000000000000000000000..7093fff32b673c06bd1dd7b09bcdf3c4deddaa53 --- /dev/null +++ b/flash-attn/flash_prepare_scheduler.cu @@ -0,0 +1,124 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#include "cutlass/fast_math.h" +#include "cutlass/barrier.h" +#include "cutlass/arch/barrier.h" + +#include "cutlass/arch/grid_dependency_control.h" + +#include "flash.h" + +namespace flash { + +__global__ void prepare_varlen_num_blocks_kernel( + int seqlen_q_static, int seqlen_k_static, int seqlen_k_new_static, + int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new, + int const* const seqused_q, int const* const seqused_k, int const* const leftpad_k_ptr, + int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static, + cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod, + int* const tile_count_semaphore, + // int* const num_m_blocks_ptr, + int* const num_splits_dynamic_ptr, + bool enable_pdl) { + + static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1; + static constexpr int kSmemSize = 1; + // Assume that there's only one block in the grid + __shared__ int total_blocks_smem[kSmemSize]; + + // There's only 1 block in the grid, so might as well start launching the main attn kernel + if (enable_pdl) { cutlass::arch::launch_dependent_grids(); } + + if (threadIdx.x < kSmemSize) { total_blocks_smem[threadIdx.x] = 0; } + __syncthreads(); + + if (threadIdx.x == 0 && tile_count_semaphore) { *tile_count_semaphore = 0; } + + int lane = threadIdx.x % cutlass::NumThreadsPerWarp; + + auto get_num_m_blocks = [&](int bidb_start) { + int batch_idx = lane + bidb_start; + int seqlen; + if (seqused_q) { + seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0; + } else if (cu_seqlens_q) { + int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_q[batch_idx] : 0; + int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); + seqlen = next_cu_seqlen - cur_cu_seqlen; + } else { + seqlen = seqlen_q_static; + } + seqlen *= qhead_per_khead; + return batch_idx < num_batch && lane < kNumBatchPerWarp + ? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0; + }; + + auto get_num_n_blocks = [&](int bidb_start) { + int batch_idx = lane + bidb_start; + int leftpad_k = batch_idx < num_batch && leftpad_k_ptr != nullptr ? leftpad_k_ptr[batch_idx] : 0; + int seqlen; + if (seqused_k) { + seqlen = batch_idx < num_batch ? seqused_k[batch_idx] : 0; + } else if (cu_seqlens_k) { + int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_k[batch_idx] : 0; + int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); + seqlen = next_cu_seqlen - cur_cu_seqlen; + } else { + seqlen = seqlen_k_static; + } + int seqlen_new; + if (cu_seqlens_k_new) { + int cur_cu_seqlen_new = batch_idx <= num_batch ? cu_seqlens_k_new[batch_idx] : 0; + int next_cu_seqlen_new = __shfl_down_sync(0xffffffff, cur_cu_seqlen_new, 1); + seqlen_new = next_cu_seqlen_new - cur_cu_seqlen_new; + } else { + seqlen_new = seqlen_k_new_static; + } + // if (threadIdx.x == 0) { printf("seqlen = %d, seqlen_new = %d, leftpad_k = %d\n", seqlen, seqlen_new, leftpad_k); } + seqlen = seqlen - leftpad_k + seqlen_new; + return batch_idx < num_batch && lane < kNumBatchPerWarp + ? 
blockn_divmod.div(seqlen + blockn_divmod.divisor - 1) : 0; + }; + + int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp; + int bidb_start = kNumBatchPerWarp * warp_idx; + int num_m_blocks = get_num_m_blocks(bidb_start); + int num_n_blocks = get_num_n_blocks(bidb_start); + + int total_blocks = num_m_blocks * num_n_blocks; + // Warp sum + #pragma unroll + for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) { + total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i); + } + if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); } + __syncthreads(); + total_blocks = total_blocks_smem[0]; + // 10% margin + int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm))); + // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM + int num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); + if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) { + num_splits_dynamic_ptr[bidb_start + lane] = num_splits_dynamic; + // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); + } +} + +} // flash + +void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, + int blockM, int blockN, bool enable_pdl) { + // Only support batch <= 992 (32 warps, each with 31 batches) + int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k); + flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 1024 /*block*/, 0, stream>>>( + params.seqlen_q, params.seqlen_k, params.seqlen_knew, + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, + params.seqused_q, params.seqused_k, params.leftpad_k, + params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits, + cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN), + params.tile_count_semaphore, + // params.num_m_blocks_ptr, + params.num_splits_dynamic_ptr, enable_pdl); +} diff --git a/flash-attn/heuristics.h b/flash-attn/heuristics.h new file mode 100644 index 0000000000000000000000000000000000000000..031ea44a0b399c7a6db4e38c273839b08eba1bec --- /dev/null +++ b/flash-attn/heuristics.h @@ -0,0 +1,59 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, int blockM) { + // If varlen, we don't actually know seqlen_q but only max_seqlen_q. + if (varlen_q) return true; + // Heuristic: PackGQA is a bit slower but can help if seqlen_q is small or not near a multiple of kBlockM + auto round_up = [](int a, int b) { return (a + b - 1) / b * b; }; + float nopack_gqa_efficiency = float(seqlen_q) / float(round_up(seqlen_q, blockM)); + float pack_gqa_efficiency = float(seqlen_q * qhead_per_khead) / float(round_up(seqlen_q * qhead_per_khead, blockM)); + return nopack_gqa_efficiency < 0.9 * pack_gqa_efficiency; +}; + +// Find the number of splits that maximizes the occupancy. For example, if we have +// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is +// better than having 3 splits (efficiency = 0.67). 
However, we also don't want too many +// splits as that would incur more HBM reads/writes. +// So we find the best efficiency, then find the smallest number of splits that gets 85% +// of the best efficiency. +inline int num_splits_heuristic(int total_mblocks, int num_SMs, int num_n_blocks, int num_m_blocks, int size_one_kv_head, bool is_causal_or_local, int max_splits) { + // If we have enough to almost fill the SMs, then just use 1 split + // However, in the case of super long seqlen where each head of KV doesn't even fit into + // L2 (we assume that L2 size is 50MB), we want to split. + if (total_mblocks >= 0.8f * num_SMs) { + int const size_l2 = 50 * 1024 * 1024; + // Only split if there are enough queries to go over the KV at least twice + // Don't split if causal + if (size_one_kv_head > size_l2 && num_m_blocks >= num_SMs * 2 && !is_causal_or_local) { + return std::min((size_one_kv_head + size_l2 - 1) / size_l2, max_splits); + } else { + return 1; + } + } + // If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512. + if (num_n_blocks <= 4) { return 1; } + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + float n_waves = float(total_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if (eff > max_efficiency) { max_efficiency = eff; } + efficiency.push_back(eff); + } + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} diff --git a/flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..8cf5ed22c2ab6f2b32ffa89fffce0f7227d91901 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template<> +void run_mha_bwd_<80, cutlass::bfloat16_t, 128, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128<80, cutlass::bfloat16_t, false>(params, stream); +} +template<> +void run_mha_bwd_<86, cutlass::bfloat16_t, 128, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128<86, cutlass::bfloat16_t, false>(params, stream); +} +#endif +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..54fc817c9ca4a9cfb9bde98593f3c4bf33a6618f --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template<> +void run_mha_bwd_<90, cutlass::bfloat16_t, 128, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128<90, cutlass::bfloat16_t, false>(params, stream); +} +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..01b28afbc0cc82b159e60719fb8777d71f61e50b --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template<> +void run_mha_bwd_<80, cutlass::bfloat16_t, 128, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128<80, cutlass::bfloat16_t, true>(params, stream); +} +template<> +void run_mha_bwd_<86, cutlass::bfloat16_t, 128, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128<86, cutlass::bfloat16_t, true>(params, stream); +} +#endif +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..5f4b207c53e4cec0edf0d50166fe2fd8c1a5d120 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template<> +void run_mha_bwd_<90, cutlass::bfloat16_t, 128, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128<90, cutlass::bfloat16_t, true>(params, stream); +} +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..5a4d6478cb24b4ffebd931315fcf6d72475d10be --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_hdim128_bf16_sm90.cu" +#include "flash_bwd_hdim128_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..6833e03dc841c25c4c1dbfab4a0fce5a0f28b4a8 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template<> +void run_mha_bwd_<80, cutlass::half_t, 128, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128<80, cutlass::half_t, false>(params, stream); +} +template<> +void run_mha_bwd_<86, cutlass::half_t, 128, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128<86, cutlass::half_t, false>(params, stream); +} +#endif +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..d02d370b5032615979ca80df1a272caad1e94800 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template<> +void run_mha_bwd_<90, cutlass::half_t, 128, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128<90, cutlass::half_t, false>(params, stream); +} +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..7981d11c2a267df33925c685952090e29f039ba3 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template<> +void run_mha_bwd_<80, cutlass::half_t, 128, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128<80, cutlass::half_t, true>(params, stream); +} +template<> +void run_mha_bwd_<86, cutlass::half_t, 128, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128<86, cutlass::half_t, true>(params, stream); +} +#endif +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..b474ace825b092c7805b4c94d033cba00fbc819f --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template<> +void run_mha_bwd_<90, cutlass::half_t, 128, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128<90, cutlass::half_t, true>(params, stream); +} +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..cc52974e4c7f70384faecec0c21a8a7260bc0a2d --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_hdim128_fp16_sm90.cu" +#include "flash_bwd_hdim128_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..41b4b01086a1aa809e485d5efe094cabd7d884fa --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template<> +void run_mha_bwd_<80, cutlass::bfloat16_t, 192, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192<80, cutlass::bfloat16_t, false>(params, stream); +} +template<> +void run_mha_bwd_<86, cutlass::bfloat16_t, 192, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192<86, cutlass::bfloat16_t, false>(params, stream); +} +#endif +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..6e81ff7346a5acf9f16b786ae72b6e163eda0e79 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template<> +void run_mha_bwd_<90, cutlass::bfloat16_t, 192, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192<90, cutlass::bfloat16_t, false>(params, stream); +} +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..15265f25b26c03ce6e8fb3e76a949b55eaac18bd --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template<> +void run_mha_bwd_<80, cutlass::bfloat16_t, 192, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192<80, cutlass::bfloat16_t, true>(params, stream); +} +template<> +void run_mha_bwd_<86, cutlass::bfloat16_t, 192, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192<86, cutlass::bfloat16_t, true>(params, stream); +} +#endif +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..0a363983ec586a9e416590339dadec8c698cad62 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template<> +void run_mha_bwd_<90, cutlass::bfloat16_t, 192, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192<90, cutlass::bfloat16_t, true>(params, stream); +} +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..1504a28475d09029661aeaa73cb3e85f0b355074 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_hdim192_bf16_sm90.cu" +#include "flash_bwd_hdim192_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..77ab70a00c0fb33a4c5baf2404a7097fa98bbaa3 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template<> +void run_mha_bwd_<80, cutlass::half_t, 192, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192<80, cutlass::half_t, false>(params, stream); +} +template<> +void run_mha_bwd_<86, cutlass::half_t, 192, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192<86, cutlass::half_t, false>(params, stream); +} +#endif +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..67a99dfc5d37c1034b2c9df8a5ad64e4b2784d08 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template<> +void run_mha_bwd_<90, cutlass::half_t, 192, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192<90, cutlass::half_t, false>(params, stream); +} +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..c1a72b0bf27b47f127eb8c110d46cb739d7ea78e --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template<> +void run_mha_bwd_<80, cutlass::half_t, 192, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192<80, cutlass::half_t, true>(params, stream); +} +template<> +void run_mha_bwd_<86, cutlass::half_t, 192, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192<86, cutlass::half_t, true>(params, stream); +} +#endif +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..75291357b89dfbb1989937a7e3c4fa9c850b0841 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template<> +void run_mha_bwd_<90, cutlass::half_t, 192, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192<90, cutlass::half_t, true>(params, stream); +} +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..fdd34babb46f0e4bc5f0aa2c8ba5e07ea33b8b7d --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_hdim192_fp16_sm90.cu" +#include "flash_bwd_hdim192_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..fa0bc55f9a6f995eb7b59a8698edf367eba6bfcd --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template<> +void run_mha_bwd_<80, cutlass::bfloat16_t, 256, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256<80, cutlass::bfloat16_t, false>(params, stream); +} +template<> +void run_mha_bwd_<86, cutlass::bfloat16_t, 256, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256<86, cutlass::bfloat16_t, false>(params, stream); +} +#endif +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..937b03c5051f28b5f1c6f6595bf29d331cd296fb --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template<> +void run_mha_bwd_<90, cutlass::bfloat16_t, 256, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256<90, cutlass::bfloat16_t, false>(params, stream); +} +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..43ce3b615548d68edb40a3df9b2b9c6ff8eb39ae --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template<> +void run_mha_bwd_<80, cutlass::bfloat16_t, 256, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256<80, cutlass::bfloat16_t, true>(params, stream); +} +template<> +void run_mha_bwd_<86, cutlass::bfloat16_t, 256, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256<86, cutlass::bfloat16_t, true>(params, stream); +} +#endif +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..b789174b1513a3793a2cac3a4a4eb0bb61ae5d3f --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template<> +void run_mha_bwd_<90, cutlass::bfloat16_t, 256, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256<90, cutlass::bfloat16_t, true>(params, stream); +} +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..7d68af3d2764b0b0e53c6915d7e47dc7e0a4e697 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_hdim256_bf16_sm90.cu" +#include "flash_bwd_hdim256_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..c6aa2475abf7c05b548ca8a1ca40e203d8bf6aa3 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template<> +void run_mha_bwd_<80, cutlass::half_t, 256, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256<80, cutlass::half_t, false>(params, stream); +} +template<> +void run_mha_bwd_<86, cutlass::half_t, 256, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256<86, cutlass::half_t, false>(params, stream); +} +#endif +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim256_fp16_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim256_fp16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..970f4d6252f679e841d105eccc88a0f1eb2dee87 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim256_fp16_sm90.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template<> +void run_mha_bwd_<90, cutlass::half_t, 256, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256<90, cutlass::half_t, false>(params, stream); +} +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..ba90899624df01f8e1173baf28eecbee9bae9450 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template<> +void run_mha_bwd_<80, cutlass::half_t, 256, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256<80, cutlass::half_t, true>(params, stream); +} +template<> +void run_mha_bwd_<86, cutlass::half_t, 256, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256<86, cutlass::half_t, true>(params, stream); +} +#endif +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..1a16a4d428c76dcd1cb29d6fb948afa42f322b9f --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm90.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template<> +void run_mha_bwd_<90, cutlass::half_t, 256, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256<90, cutlass::half_t, true>(params, stream); +} +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim256_fp16_softcapall_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim256_fp16_softcapall_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..cd471362dfcc43984ef5b9e08c9b250388900784 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim256_fp16_softcapall_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_hdim256_fp16_sm90.cu" +#include "flash_bwd_hdim256_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_bwd_hdim64_bf16_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim64_bf16_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..7b242432c662e52defd583f4ab676231903063e2 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim64_bf16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template<> +void run_mha_bwd_<80, cutlass::bfloat16_t, 64, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64<80, cutlass::bfloat16_t, false>(params, stream); +} +template<> +void run_mha_bwd_<86, cutlass::bfloat16_t, 64, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64<86, cutlass::bfloat16_t, false>(params, stream); +} +#endif +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim64_bf16_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim64_bf16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..6a90f01b86a5e22f9bbef3f5ec7b5d692e1450df --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim64_bf16_sm90.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template<> +void run_mha_bwd_<90, cutlass::bfloat16_t, 64, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64<90, cutlass::bfloat16_t, false>(params, stream); +} +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..fc6116b13e797248cd9d1d1eaf41da8f18181727 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template<> +void run_mha_bwd_<80, cutlass::bfloat16_t, 64, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64<80, cutlass::bfloat16_t, true>(params, stream); +} +template<> +void run_mha_bwd_<86, cutlass::bfloat16_t, 64, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64<86, cutlass::bfloat16_t, true>(params, stream); +} +#endif +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..481c1a7e895c6136a5061cfa3776019e642ae2bf --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm90.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template<> +void run_mha_bwd_<90, cutlass::bfloat16_t, 64, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64<90, cutlass::bfloat16_t, true>(params, stream); +} +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim64_bf16_softcapall_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim64_bf16_softcapall_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..414bca23c93bd1a392237ec5a9c33daa8a540044 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim64_bf16_softcapall_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_hdim64_bf16_sm90.cu" +#include "flash_bwd_hdim64_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_bwd_hdim64_fp16_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim64_fp16_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..bd0205076c0e23b33b9a32f21294cccdf14b486d --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim64_fp16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template<> +void run_mha_bwd_<80, cutlass::half_t, 64, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64<80, cutlass::half_t, false>(params, stream); +} +template<> +void run_mha_bwd_<86, cutlass::half_t, 64, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64<86, cutlass::half_t, false>(params, stream); +} +#endif +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim64_fp16_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim64_fp16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..4a927c6322bbd6c06fd035c0225937263b723535 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim64_fp16_sm90.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template<> +void run_mha_bwd_<90, cutlass::half_t, 64, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64<90, cutlass::half_t, false>(params, stream); +} +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..401e817bdd38274fcb20c722ce208e64ca41ad36 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template<> +void run_mha_bwd_<80, cutlass::half_t, 64, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64<80, cutlass::half_t, true>(params, stream); +} +template<> +void run_mha_bwd_<86, cutlass::half_t, 64, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64<86, cutlass::half_t, true>(params, stream); +} +#endif +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..d726b62c1fffa623570a1e124a6eb9eda58f2c5c --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm90.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template<> +void run_mha_bwd_<90, cutlass::half_t, 64, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64<90, cutlass::half_t, true>(params, stream); +} +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim64_fp16_softcapall_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim64_fp16_softcapall_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..10697f012d82724bfce5db1c36d420b30bf28cea --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim64_fp16_softcapall_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_hdim64_fp16_sm90.cu" +#include "flash_bwd_hdim64_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_bwd_hdim96_bf16_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim96_bf16_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..326a796ed257057d3e86b8e0f30e6852ae1961fa --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim96_bf16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template<> +void run_mha_bwd_<80, cutlass::bfloat16_t, 96, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96<80, cutlass::bfloat16_t, false>(params, stream); +} +template<> +void run_mha_bwd_<86, cutlass::bfloat16_t, 96, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96<86, cutlass::bfloat16_t, false>(params, stream); +} +#endif +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim96_bf16_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim96_bf16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..f63535737a3f198f28673f33b108ebf67663f4ea --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim96_bf16_sm90.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template<> +void run_mha_bwd_<90, cutlass::bfloat16_t, 96, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96<90, cutlass::bfloat16_t, false>(params, stream); +} +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..63522b87f9163aa73530acddab50c2059e763b6b --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. 
+// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template<> +void run_mha_bwd_<80, cutlass::bfloat16_t, 96, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96<80, cutlass::bfloat16_t, true>(params, stream); +} +template<> +void run_mha_bwd_<86, cutlass::bfloat16_t, 96, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96<86, cutlass::bfloat16_t, true>(params, stream); +} +#endif +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..b16a8d376d259f3ecf158de3708b2d7704a79d06 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm90.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template<> +void run_mha_bwd_<90, cutlass::bfloat16_t, 96, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96<90, cutlass::bfloat16_t, true>(params, stream); +} +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim96_bf16_softcapall_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim96_bf16_softcapall_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..f6593e1801a05ecbe5c79994a1155745d418df13 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim96_bf16_softcapall_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_hdim96_bf16_sm90.cu" +#include "flash_bwd_hdim96_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_bwd_hdim96_fp16_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim96_fp16_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..92ff4ad1ad9a72e941a9e2b8bdc6160eb28ecfe6 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim96_fp16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template<> +void run_mha_bwd_<80, cutlass::half_t, 96, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96<80, cutlass::half_t, false>(params, stream); +} +template<> +void run_mha_bwd_<86, cutlass::half_t, 96, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96<86, cutlass::half_t, false>(params, stream); +} +#endif +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim96_fp16_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim96_fp16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..56ab6760d68cea3655273af8dc5ba42a571ddf25 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim96_fp16_sm90.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template<> +void run_mha_bwd_<90, cutlass::half_t, 96, false>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96<90, cutlass::half_t, false>(params, stream); +} +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm80.cu b/flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..0896abb5a6b399541555aa2589d1886c5f39a798 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template<> +void run_mha_bwd_<80, cutlass::half_t, 96, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96<80, cutlass::half_t, true>(params, stream); +} +template<> +void run_mha_bwd_<86, cutlass::half_t, 96, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96<86, cutlass::half_t, true>(params, stream); +} +#endif +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..7d92efc0ce94642ca24f533ef6ff414be16a6d9b --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm90.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_bwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template<> +void run_mha_bwd_<90, cutlass::half_t, 96, true>(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96<90, cutlass::half_t, true>(params, stream); +} +#endif diff --git a/flash-attn/instantiations/flash_bwd_hdim96_fp16_softcapall_sm90.cu b/flash-attn/instantiations/flash_bwd_hdim96_fp16_softcapall_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..82265ca9edae8df97daa3ecefa5c211bec734a67 --- /dev/null +++ b/flash-attn/instantiations/flash_bwd_hdim96_fp16_softcapall_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_bwd_hdim96_fp16_sm90.cu" +#include "flash_bwd_hdim96_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..affc7a4dd96b51efc4a3f5e452731a6d94ac33d0 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..7e13614bfeadbf03fe6f7b6d74d425c010f4bcd2 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..670041341bcf3f883cd9933801434faa52209228 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..f315fbb45454d3d6ffb2f0471e0e8fa5ad27b575 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..bde3024a4a68ee6ce74a98b48e450c2cdbd12078 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..40afcfd68b17a4fd73e7053aae7d7d2b83a49c46 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim128_bf16_paged_sm80.cu" +#include "flash_fwd_hdim128_bf16_paged_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..2724463e621137c98ea37b303c3f70307fada8ae --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..a38a1d5cf33d3839adff7452e5a8ff609cec2638 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..284eeba18235c2acf4b6eccce2f670ea19e27495 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..0c40ddba8fe6814103fefab86653c08941ce8003 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..5bed85a0ddc4cfabf33abc1f9147f699ca5add97 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim128_bf16_paged_split_sm80.cu" +#include "flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_sm100.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_sm100.cu new file mode 100644 index 0000000000000000000000000000000000000000..4fb8f71d01e8faf7648ac1988dd4f908870f25b5 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_sm100.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<100, cutlass::bfloat16_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..cc89c4d5d2559acc5a53e9d163494920af7140ef --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..3a236b712c4d0c4660ee6113f5fe39d661537498 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..8449104c5aa188ab4416ab85ef9d3f7383605e7e --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..b152b90bab7c9f50bce1cd5e3d6949182308a6d9 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..8cc4fed1739eb7b50b8479f7aa8ec608d5667816 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..eee51630db1e8c1d65ee5453c8584d8f25b8ac02 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. 
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_hdim128_bf16_sm80.cu"
+#include "flash_fwd_hdim128_bf16_softcap_sm80.cu"
\ No newline at end of file
diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu
new file mode 100644
index 0000000000000000000000000000000000000000..1db3f1e6d803ed909f727a05ca81fc7e74c3b5af
--- /dev/null
+++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu
@@ -0,0 +1,12 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_SM8x
+#ifndef FLASHATTENTION_DISABLE_HDIM128
+template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
+#endif
diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9b3e294f1b3766502cd4c69cf2ddd89d3424b188
--- /dev/null
+++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM128
+template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu
new file mode 100644
index 0000000000000000000000000000000000000000..07bd687fc344c96299d80a6a991e1633ef03b0f8
--- /dev/null
+++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu
@@ -0,0 +1,12 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_SM8x
+#ifndef FLASHATTENTION_DISABLE_HDIM128
+template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
+#endif
diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu
new file mode 100644
index 0000000000000000000000000000000000000000..5f44833b10da9db301a67de6247d4ff36b530e47
--- /dev/null
+++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..f51f96d92665a81732981ae15637a9b430a29c42 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim128_bf16_split_sm80.cu" +#include "flash_fwd_hdim128_bf16_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..9f95ca29f6be2cebf53b454ece8eea0bb418bc3a --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..ad97737d4f3d2c02521338c2528d2670ced4a17d --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..d77d37ec0414df8f00d8ca3ad4fcb4abc3db6e1f --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..ae05c7ce5f0a87b7081478571d466bd4ec5b7a7c --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..bc52a9f356fd3640b7e4c55995026758d4508f2c --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..480d485d06908cd945e3b90038f471701478d6a0 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..d3da5f4e6656c697f6da20e6dfb3d5e0bfa78d08 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..1c1c2d8207fd419136dbbc98429e8d0f32048f88 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..371d933e3e1e8cdd763644cb451aa036aa666497 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..7491148dcdebc67ed4f01ddcfa8ecebc86e97b9b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..d04159a62a0aeb11465ca77be37057aaef1b78ac --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..28ad6c149637ab70a0f461b146758df6b119d6cd --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..7afb267e3ebf9e4b42b695c07a80aafacce23f43 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..69758584cb63da1305a294e555cff087e57bd716 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..3be45956bb4febcea1ec0d835da3c774664608fa --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..29b97d0e149b8ee261a2b08d9e00dee1bf40d70c --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim128_fp16_paged_sm80.cu" +#include "flash_fwd_hdim128_fp16_paged_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..698095dad6a0a37a06e708b92fbb4bb62716e3cb --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..16d443a9ad170e6471f71caf44ea5f20f0b17f00 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..1e8f6af71bdabf35553d2e7c972f84e73c09e7b5 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..4ec688861124107eceb82634baa8698d19fa1578 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..b745a57a2042d74ab701635ce9e54897f353937d --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim128_fp16_paged_split_sm80.cu" +#include "flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..670b5952d9d1e4a65f0d3f7999befa19ce33d1c0 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..b9778dc92e1b0c39197fa54ef7bf3de344e25505 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..446e917c79536e41113457299df62384a312faf6 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..fd62a2c54352cdcf60dda75e0a7d847a9d2664dc --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..0a397f4acf2ab3a1c69aaf643b00d5ec9f65c1dd --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..2292272ac2a45410117b72d2d1f9c8547e70ce4b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim128_fp16_sm80.cu" +#include "flash_fwd_hdim128_fp16_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..4d3c553e296a5cc9fbaec634a7916af400fa9130 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..77621846ffeb8ee66e4cb1a225c5ce85759ad905 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..7d217ac273384dc2fa989e908ec02bc28d29aa72 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..0b6430abc2ff56903f969d6f53ef88e0ff0604c9 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..58891d23eb60fe64a533a05b6cf6e888b368b2f2 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim128_fp16_split_sm80.cu" +#include "flash_fwd_hdim128_fp16_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..ea1e266f8d4cd2621be685f4cd2aced78a8dca01 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..2d7488fefe29e627e4c7c9192db54b764a74e7ba --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..8718571e30c2b1177766a0a81b0b8a86d361d6da --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..f7dfc18fc1ebace9b4c6e8527b4f4a1f3a70ec7b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..935f5a0fe6084fc3dce8bf93eaa582b6cf59b32f --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..3f4d858ff572ef21225c2a41bf41b3891806a37b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..54d720efeb32b028401a6d619213ce6f31aec598 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..b9b93af4fc50033570b8345c86b7731fc506fb85 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..39d9167b9f19817a25ed02f685c78af4a54e322b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..0f86458012a8649975a0d611234749ab116fbb7e --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..bd6f4df8f696c154e1d8a735d6b7b2c1ebce8087 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..1824b86c64ce3786316879b9b02eb7cc3d5c6e9b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..87dd01725a540f59c29bff4f89b1eca9cf7f1926 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..6594d56012392b19ae445d2f6e72501f81a5b837 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..d7dc84ebc1c8f009117f97aa6b941029f8560aca --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..b9d6e54cbed3fa19cd04ae6460355b9f5b6bad62 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..a8c47652ec15f8c395fe0662604f490a805291d2 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..32d17c7665d612a0524e0cc47cb971c8f12c903f --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..365017c256d027bc48a01077db434a17c0bed18c --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..82cfdf040b0dd0279a4245bf1c4ee61bca855f37 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..f3254936a47f004f0a261138645e3ff4d2a4dcc0 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..931a6dbf86985659cda2fb7bfb83c8a3ec83a398 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..5c8877a756dd56a03ffaa3ce9dec728a8a65f752 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..1e230ab084bf830cbabc4076526355e63452f7fb --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..03716c862372e3d10310d8c208bc7d5f38a6d430 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..54c66c9552e994ecd8dc7b6e0397138fdbb04cd1 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..e5e0ec47db1ae18706d39fb992d38de0085ab8bb --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..e4411b5db32833541e2e1aeab5526e9415a78c74 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..157ed06dddf66ff9cd465b5ac700d7c319007339 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..7ef5adc9e85d9ef4ff3dc7fc2e78e693dbf6cdb9 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..bf8386b82976f8ca6ada1118e6f4073ae4624413 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..cbc6f988424eb5c10494f0133b2a148197f6171a --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..d5aa15b5c8ca63eede4368b0bd478a3f9e5760aa --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..b8593612df3d3054d7bfb52703552a718d58b384 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..a03514d919b075cf4d092a878e6cd22d05294441 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..a7f3f9d0fe4d247cf8d16f43dcc44476cb40bc85 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_bf16_paged_sm80.cu" +#include "flash_fwd_hdim192_bf16_paged_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..df547749e93a135c6c17ba1a1b1615b001dde447 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..1ddb191620981189062f3d745e159ef4c66e9df8 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..cefffcd2169977a9c1cf54ba5e517b99c276422e --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..3d4333b9e1f0fb136edaf54378c5dd688292ccba --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..298b0cabafc3ec75f84ae7e4e3484db325267e5a --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_bf16_paged_split_sm80.cu" +#include "flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..35a2abef8c9e02931e6a8b4847f614fdf3ac60d4 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..99e34ac0bfb09fb5b4f386d604122c42bff2e38c --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..ed1cf22d5c410ac8f47e82314fd70f38e18a152c --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..4527d9a27931e4d244b25e2206ad8c1861d8ea7c --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..41fcf8001701b6581753b9184a478967826cdbe6 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..d8b44171204e33fb1cc2478e6a3c0615fc07462f --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_bf16_sm80.cu" +#include "flash_fwd_hdim192_bf16_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..704cbcb337e95ac6f64b10e69be538a719579f2d --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..e0ea082156bb79dad4694b06eb3ece76c4d3ab7f --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..a9c00408a8b2c919a270fd1db292d1e6575fd3a0 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..1497e7aa8430b300f003335d7b9fa90e8e40fc1a --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..95adb9e2b4f99634b6623cfde868d74d1064777e --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_bf16_split_sm80.cu" +#include "flash_fwd_hdim192_bf16_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..c66ea9baca1b3935de59247efeafe2f20b48632f --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..a7e472b478b137bf9819c5cb577cc4f000e3b964 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..9f090aeeda8b46ec40c3aae73e1986288e868eb5 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..2205168a67f4cdbea59908f30076937f5b5f0372 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..2a01898b5604a68c52bbeb300c46bbaaaa6026b4 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..888e241a9f1f1f6ddec5cf8247dab23fe16e089e --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..2a6bde7a39fbd40d2d528cb0ea242e9140d09bdf --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..3d315187b2d4585ede8e6665421e6d3ed743036f --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..3c3d0938034d39d70293feff0b5e4e2f22701748 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..4ca103566d673d97f7a1324eecb6795105ced521 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..16debf27799fb8f35873208b60a77e0e39a6b596 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..43c2615718e27ff689ff6c19e774732e77625531 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..d9d483838f11c228ed21c3e7baf8883cca8811e4 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..70543998d948f776cf420d0cb6525a087e0460cb --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..c30c7e3b8b9022cab1e486032ed16b6220b794ec --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..4bc8c4264b07fb1495c9db298839f3a4e95a1ede --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_fp16_paged_sm80.cu" +#include "flash_fwd_hdim192_fp16_paged_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..7ae26e69c96bf2a4fe3f3d8f9cab8c93e490dbb1 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..155b5a539fd9345905e3fdb7162898bc8fdbca80 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..3e6173c31c23a290ee1f13d97f2bfbc22798bdee --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..e1e3191a202e0033c3e7079a3742bf12ff22473e --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..6316f81e7b9cec8b08f58981cc5f5d279b958eea --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_fp16_paged_split_sm80.cu" +#include "flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..8272ecb76cb40165e09f0a428cf89f5ace6ce456 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..74606c3937381c01a0a01eea84825a0e0cc2510b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..89a58502b375a91a441cc1131d1ecd7e1ff0673c --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..b13373806a1a008a9ab20e89bd0f065864f3c955 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..1335fad7f2b66c58822bc8be52134aabdc0099e1 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..cc59a01dae9aed4d8a868c79c7aa7b4328595b41 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_fp16_sm80.cu" +#include "flash_fwd_hdim192_fp16_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..18c31bdfc0ea9dc93230d82826cd86782aaea68b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..18a5603cf1f51f5c97469ba94cf3adb872c90e7c --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..4e99c7db0276402198aa4eabd1f5228599907a28 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..82f8204aa66609bba5b845979658e840ce18eb27 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..4314bc82236e0e8897dbaba5aeadced086fd665b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_fp16_split_sm80.cu" +#include "flash_fwd_hdim192_fp16_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..cb851a77110c7d0eb8b43840a083fc795139f707 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..ae2871c16558ebe5c67d7321be8a31811ae651fe --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..ed24fbffef9e2e53e2e490cb034457c7a1e5bdd1 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..ffca9c7f8fe36e11efc1cebe9de6cabb4dd493c9 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..57a06bd6e66002cf7f91736cfab6c7f0505c08dc --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..47db0cac3dd20bc8e0ce98d1c02f101166d169ce --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim256_bf16_paged_sm80.cu" +#include "flash_fwd_hdim256_bf16_paged_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..ccdcf21e492929ca83e2c33a3970fe7b11314713 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..c2bc7787765e1929bfe3cee0f1cc6aadbac0c449 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..6bba953fc69863996ae757310c10b1c5b39354b7 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..25c96174c79bf203dea2bb301dd4ebbdbd5a6457 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..7f62ded549444ab65960a3b9279fb735a22f30ee --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim256_bf16_paged_split_sm80.cu" +#include "flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..f172239e5b95686b51dd98271d58769108f4055e --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..9dde6adb04b6c02c0b69d9baa12d082d99b1b06b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..2317adef8c53ab5615782ee867ef2dfa0bdd5ad2 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..b9b3b74867e28f5676e827010c84e636b3f2e346 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..c57a5a30abbd8d820b9ae9aef1b4fa14fa56e4b4 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..12851474031205db74385d6fbaaf097b6cbf548b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim256_bf16_sm80.cu" +#include "flash_fwd_hdim256_bf16_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..4f59a6aea92267476a8c84eedd9ce9abd9f51a71 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..2c2de1574ac08ef01b7353be8d0575f35997269a --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..0dbd062c79f8e503ceb8ee99cc2124930c753808 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..bee54c702de1b1b2727bd58b55eb600af8f52785 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..2007abeeaba5ec81cf6eaae99444e09cc549f48c --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim256_bf16_split_sm80.cu" +#include "flash_fwd_hdim256_bf16_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..c02e6833494d0ba4f9795b75262e06a8cfd36215 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..02b50b98b8d1a31492b772683c5a4b566e65921f --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..6599de63bbd654300a2833eebd2f00bfef3d0b1a --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..a1cdc775cbb78a85ba25016457ebc68dda3f6241 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..6d01be60f58ab0a6fc1bf23f48235a196d7da29d --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..968bbf36f838107dc345f65f4d6532ebce037699 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..d564a622111d492e3bb2f6017d1671dc4dde1615 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..cb5bccc176cb570d3be44d3820fce4459300a821 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..146a7bc3430bccee015f6e3dbc120e7e6064fac5 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..a195e0931c00565f172c144c4fc2b9cf9f7f9e43 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..045fc71bedb78c444d2cccd4eba49e4b6758d666 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..a31da2eddf4ea3ccc95084e5a9779494c5022013 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..7382b58a2313acba7ffce57acd16acd0070b135f --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..87ca31ce9027333926b94095d6f4ef3745f8a602 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..60f4d6ebbf6ec4dae3b1d50a47f06d3d4151bef2 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..593bba5280a61c75520dd84f2ae616c2126fb43f --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim256_fp16_paged_sm80.cu" +#include "flash_fwd_hdim256_fp16_paged_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..e0d5d318bcd26d9370e8d0e4ac556e32270e9c01 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..dec7db046bd43d2fd30970e1eff42f47ea5df83f --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..7b71f4352260581968c725181779601f71771e8b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..08fc989af8b1feaf84797b4294c61a23231c922d --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..f1adde60be67e2646eb683195808f4295785a83e --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim256_fp16_paged_split_sm80.cu" +#include "flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..2cc8b5b86d443434af5035210ce09a60ffcffae3 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..644e268469fa4ed900b56dd203220c9155ca0774 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..1ebcec8b3fef8ebd9f2ea03e729500e611f5773b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..780ade7f6b8451d145ae3ecf0ecf1c75fae842a9 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..bfcffe2a39a9f516dc79d06aea66ca7aff43ae08 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..5724ae3309c6b0353abc058d0a26b4da8784b675 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim256_fp16_sm80.cu" +#include "flash_fwd_hdim256_fp16_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..ba4ba78ad49ea806df748e538010b83c081987d7 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..f04260ba4f1acfb2049b7f623e9b667ca0c97fdf --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..33c78e53059d96fc7a0d556ca80633debc504634 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..8388420921ee70bd1218e272be90fd5180baef22 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM256 +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..fe6b77bb8d23cab3d3e764728966c0459c826bdb --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim256_fp16_split_sm80.cu" +#include "flash_fwd_hdim256_fp16_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..8d037153cbbfaab52cf3a64010e9f068f525ba1b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..c62e0b8d822728f7bb0b6e596e20daf8365bc273 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..5e22d67f70027c14ce7038ebe9ed1131b56b1e1f --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..1e005b3f018297375570e29729389a2994071064 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..96c4f55afdbdc9baa8b22b952bb1417fcd2ee7cd --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. 
+// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..8a92fe291ee1074590274e52742bf791137c3fff --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..f47cb326674059880736e25168dfb1cca0abf97b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..1915feb0463cf7f0d0107d0e9b17fb69e97e2fe4 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..fbc15776610254c568e1a36b7a3d9c8c49583de8 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..88445691ffbad68dce341979a4be1a506be9d2bd --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..f7d051a34d37a955c0393497b21a2174b0d99ff7 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..c83c1741d4f91d383f63d3445e4ccb9b407ae191 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..2e06c89a8c7c9ce1fca71611da21abb7e9b23113 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..46479ec15e14c4c1743d52cc94de58eb67eb154a --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..18681ec42b416129faa090292f1edce49d3801f1 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..d2245aa136aa026b8b39cf5448f32dc930f51e7b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..022cdd395768e2d029927bbd954020ab698247a2 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..67a324d52e805a0856cda8fb154efaa876b2180e --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..664f88dbfce0e9ae267764bfa642a1e85353779f --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..6bd6b9ab38f6b632b2b16c97643c3b099cb95c1e --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..2f4ceaaed53398201d5666028a9e0015e14426bf --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..5fd59af3486edc4115e8c6e0f745dccbe60d6d10 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..e0f885b0f72ae4858470b72930b9ccea8ce8eb40 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..6dcda019627eeb34f7622ea4fda6ce9461b2bb95 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..5d20be6d2a7428fc95b7be228cd50b46b9f2eb9b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..47463a7151cbfdcbe3a59f4c009fc249886a38aa --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..622b5533ce85fbd8ba7818b23b0fdbecdd70c2f8 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..c83f44722cda20dea1321a1ca9de61bc8d278e79 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..5c9130f86481367208c67397eac23c9cb7ace3b5 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..a152022cb650340d0b409f419a5ca587054344d5 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..ef05aa2038dd46a2bc6b5f43195eca342ac2bd6f --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..19fe6d94f7dc6776c8124deaec9db7dd50dc57bb --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..6eb2d3d134bc9320a256f1b06164dfca347a7119 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..ffbc9982122290c2f296c4c34c2bdc15f3471f90 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..3d35075b48d31bfcdfd43e7407573d981da2f949 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..c2af33cf533228efe2bbfe89b82b1ad4578e131a --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..e07547c92d0391f3436103188c8dc56a1acfcfe5 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..1a04eb01f5e9143020bbd562bd477588e6740881 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..da9afc1157167534a6b60718cfac914b977a6817 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..5e63a15515fad74539074fe8781f4a63f0bc1742 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..4134d7d80bbd84711c0920e68ff22bd7f86269b0 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..11e3503b0d9263f11bf1f656e86af0ef4e95f90e --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..67e39bd7371b72956ff47843220a07566205dd5a --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..c37844daa562b8d143223db22981c2733ffa9cb6 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..f0c40e2f89fe9a41826447a4806788d149bd15ed --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..850a47e25c55352d371818a2a9cfaf7025ef1db8 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_bf16_paged_sm80.cu" +#include "flash_fwd_hdim64_bf16_paged_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..3ed9694908c466b3bd153418a672347359deb640 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..4a16aae66c00c5fc8659dffa2657428fdd37086a --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..b5b5fc26b280acefe728280db2f809e285f08fc2 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..3b29be627edc6b366eb635181619b77b9469d5dc --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..3cc805d00aa83e4ff8e43ef44f9d091ed359d548 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_bf16_paged_split_sm80.cu" +#include "flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..5f1c298c4c409464d70fb1a30eedf54102b9b460 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..64895643d2009214caa4dbc2d61785d98103bd68 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..dd508590d66f6ebde86b709a4dbb961f1cdfb9f4 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..8411b6fccbdf0d1d1a7d140cdd3924e220df5be0 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..b5b4f40770e5576d45d68e43fe12de393ed10e76 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..b7972b68eb925c3615b4a951779c7b3605f8cfff --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_bf16_sm80.cu" +#include "flash_fwd_hdim64_bf16_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..e608da04b022d3286af44a87b879c19f39f13870 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..c69b78ac3b6015cccc0fc7a39df15bae0230afde --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. 
+// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..170cdb5cb8c6e6795aa75db821c8d22a01079c67 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..ef0d1e921c1cecf2fbbebd04ea2b7a1d46fe5eb0 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..f8adb78ec202909f35e68515bbfb58c16a2973c6 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_bf16_split_sm80.cu" +#include "flash_fwd_hdim64_bf16_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..6a7fc29dddaae4145b4d6e8f760032871e0e1c60 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..faeb6c487fc6fb459b6261a2a2fdb9f124f9149a --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..655258d5194232d03dbd292c41693fd0c63a8940 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..4bd8ad8f267f5824d862d1343fd116359131850d --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..657820f28547f075a26085c447b6f79aa7a43ace --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..cb0955d1a53e1dd4f039dfb35efe4cb0302e8f2e --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..357b64e83b67bf763b32e969505ef9686dc85291 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..c1207925864503e30feeae9af161046c4ad57c5f --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..21687f8932b72c2d6323e2346ce8bc0d74f62eb3 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..4df8ed64d7be8192ee6ac05ac83dd8726796c419 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..b601195d7e095577be2cc2f01425cbfc5d6d2ced --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..ced475318983e384d1802f0826ff9d38165c52d9 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..03090f73cb250b897f3392d41a1b41a28aa3ee39 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..d6fe1559ca11407e0a2c75bc6a50d1f501f8c1d2 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..7b5ae4a56aa2ea5265bede2dc88e0e4cba3ebe35 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..4dd44519961b256eb33bb4727b0790d9bfc88946 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_fp16_paged_sm80.cu" +#include "flash_fwd_hdim64_fp16_paged_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..6c603b4dcaaf6e4d6f70e25470a52cc69e1f8406 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..26d25fc1909b3be0251f7b7f4d4a8fa1e867aa9c --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..05a0baf18b6a0dd2fbe6bfc08e104d1bb5811f7c --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..3a45776537f1ba7ad56b7bf873d413bbdf0caa10 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..c7348b075735a222020f73383eba0bfadfd9a36f --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_fp16_paged_split_sm80.cu" +#include "flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..9b80bae51f6e579b1993f976343b453145cf30de --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..f6810efafb865253d66c2797e7fcc39bf446f87a --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..98c018893f1e63c8d6e0e81edf0a622a15d7947b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..a10dfaca722bb629cdecb348a3800bd519e3f549 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..b912a81443e2688aa2ada96b5181c00d0d9fcbb3 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..6e2756085ad05e72e5fce7809847434becdbccdd --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_fp16_sm80.cu" +#include "flash_fwd_hdim64_fp16_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..8603c396e1f4f8edb28f5aaf632f10dec948d75f --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..dc55dbc66aaabfddc6422614eb0055940923163b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..ef48844972a43736f44e7550adb9739c4287d178 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..b1c0ead6e5c48dd41d6356f4cf56da69d64322a2 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..fa4c8095449af058afbd457723e6f3ea2c28ca1d --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_fp16_split_sm80.cu" +#include "flash_fwd_hdim64_fp16_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..5d76d0fff0466c5eaf3c23a0b9abd1c7ee31c5be --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..44ea823d272b7f76d6be2b29aed8d7d15d2d6fe2 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..30fe623508b14c4584c56d72456be3d26559f9d2 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..6eb12dc80a60248305c0e2dc5e297380f0073361 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..b806fc9d501dcb79c05728a02b8564bab147f8d2 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..989d42b270840723d84286bb3474263557800d52 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim96_bf16_paged_sm80.cu" +#include "flash_fwd_hdim96_bf16_paged_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..8f0a26da03a116ea3b25982f40eeff3823cbea2b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..6de2819a17293dab4d23d8cd875350b4d1627e8c --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..16927295b8203e797f17bdbd5998677072b64e0f --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..0841307209230557940e4a27d36afaed02865df3 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..72654c12104cd08cd23c83321b9f2249a4af843b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim96_bf16_paged_split_sm80.cu" +#include "flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..7d4dcdc293bdfe23c1167f2af9da53d1b4745006 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..b4dfbf7f8b2c58bd7a3c2a8573fc2a9b11badb5d --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..1fa048752dcea50711ea473c27ab0b33e9fe30b2 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..e0b6a75e63581cef18e85418c7f26579037dd1d0 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..e257b42f79ccec1a9647ea73b6b882aa86a2f668 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..31268ea3dc58503cc7f8e42c2cb276494bdcafe5 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim96_bf16_sm80.cu" +#include "flash_fwd_hdim96_bf16_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..f97ab4733a08e66c7604e14a6b75de4628ecd952 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..cee43ef94cdd9e3bb8803284f1d2137fb9d8b835 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..0442e1f94b533edc8d7cc548713b3ec84b6691cb --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..bc71fa9e71fc89dd5c029faec448557def185ed4 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..435607c03cf2ccd32bfe5928b47fccda4eb5b333 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim96_bf16_split_sm80.cu" +#include "flash_fwd_hdim96_bf16_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..b61dd71885d9d13ed6e6e4e9883606f4483a8afd --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..f47e1f5cdacf18f2758b0ff74a7208113ba5f17c --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..215752f1b0621a3823bf8c854d287d35710afc27 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..207afc79242fec88b1f2dca95c1214b736613ed4 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..6c38c083384b5968c80d53aacf9d15b4bd8bd01c --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..dc2eb35dc2984ac0d4c8599a2177ded3c92a1c28 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..f04e8bca6f375f75e0973f673c66d099e4ab2ea7 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..2697f6910e7ec9e137c78c38c141facc62ba583e --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..e7a98b2e6ee45014645d64d4300b8141f12d0acc --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..98fb39c86ee5f593eb746dffaf8e3ff2227a27e7 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..cb938ad93b05592b94ba5d9fa181a70a54fe7125 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..e2dc45c79c6cd21a2f5b22e75fbbbbfe9207dbb7 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..64f99c05a3224317360ece6215e24a6297709602 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..3fdbbf23bacd71d9f876cb2dcca8e6db5c17c16f --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..ffe202ee394bce5d86c156bb854fed0c51dffb05 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..3ce908b2f97a02445e9bcc026b633f24c6ad8b17 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim96_fp16_paged_sm80.cu" +#include "flash_fwd_hdim96_fp16_paged_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..42740f0228b3e6295901bc3f99495817e1dd613b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..829929980d0d605c5a8c1555d2d8a436f4ca56f5 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..d6a330432a4579785ecbaa3b338664ee9bd1ad9d --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..39c774e6f7797bd9158a813af1c09e8829076bad --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..2d504315825da06c5b06cdf2c79d0f8dfb32d758 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim96_fp16_paged_split_sm80.cu" +#include "flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..bc54be11e6c3d446dae7401a0f159cb26ddfb280 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..a68790500d8701f24d1a3e937dbd9733f715f867 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..3bca3065c7f8117ea903e2b755ae379c4a6952f7 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..985692b9fa1aa6040910b54a4a6e38cf526e0e21 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..3c99cb6b5a007154f2f18071b5c67e0a63b2c149 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..1c9565de0d6d6f22c21040837753e16386033b6c --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim96_fp16_sm80.cu" +#include "flash_fwd_hdim96_fp16_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..cf77a1ae819172eb1662f48464a1dcc6d298681d --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..f9a46a44dd5c0c96a9b50ca3268836aa77fa401b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. 
+// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..9b4dbbba58aefd8cf4498b1333d98e75c10ac81f --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_SM8x +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..da5373fd13e6fc64ef45fa159a286dcff8af9e33 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM96 +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcapall_sm80.cu b/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcapall_sm80.cu new file mode 100644 index 0000000000000000000000000000000000000000..921fcd4292abf6d6dc0d9b6e8b16677f8f3c8325 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcapall_sm80.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim96_fp16_split_sm80.cu" +#include "flash_fwd_hdim96_fp16_split_softcap_sm80.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..8b659e8321b7ca811f5422d4fa09c97d605f2020 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_hdim64_bf16_packgqa_sm90.cu" +#include "flash_fwd_hdim96_bf16_packgqa_sm90.cu" +#include "flash_fwd_hdim128_bf16_packgqa_sm90.cu" +#include "flash_fwd_hdim192_bf16_packgqa_sm90.cu" +#include "flash_fwd_hdim256_bf16_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..c84d02b6d04c7daafaeb3e5eccc1e401bbe18712 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_bf16_paged_sm90.cu" +#include "flash_fwd_hdim96_bf16_paged_sm90.cu" +#include "flash_fwd_hdim128_bf16_paged_sm90.cu" +#include "flash_fwd_hdim192_bf16_paged_sm90.cu" +#include "flash_fwd_hdim256_bf16_paged_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..6aaf7d12f5677c5beace8562b950955a4da394e6 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_bf16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim96_bf16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim128_bf16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim192_bf16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim256_bf16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..117121414197181eb850e2b73abd3c26ffd5c4f7 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_hdim64_bf16_paged_split_sm90.cu" +#include "flash_fwd_hdim96_bf16_paged_split_sm90.cu" +#include "flash_fwd_hdim128_bf16_paged_split_sm90.cu" +#include "flash_fwd_hdim192_bf16_paged_split_sm90.cu" +#include "flash_fwd_hdim256_bf16_paged_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..6175723086cee981ed4cb7ff133db55db0c13f3a --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_bf16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_bf16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..2aac1970b1b0322f086e01df74f7d5875c202479 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_bf16_sm90.cu" +#include "flash_fwd_hdim96_bf16_sm90.cu" +#include "flash_fwd_hdim128_bf16_sm90.cu" +#include "flash_fwd_hdim192_bf16_sm90.cu" +#include "flash_fwd_hdim256_bf16_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..be0c5af080bf713b053d426d58c3753352d04568 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..fd5893c59f432e312162d9a8da24592c35ed09d1 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_bf16_softcap_sm90.cu" +#include "flash_fwd_hdim96_bf16_softcap_sm90.cu" +#include "flash_fwd_hdim128_bf16_softcap_sm90.cu" +#include "flash_fwd_hdim192_bf16_softcap_sm90.cu" +#include "flash_fwd_hdim256_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..bcde9c9458291f1ec9c1c7cf4705c9ad601a95a1 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_bf16_split_sm90.cu" +#include "flash_fwd_hdim96_bf16_split_sm90.cu" +#include "flash_fwd_hdim128_bf16_split_sm90.cu" +#include "flash_fwd_hdim192_bf16_split_sm90.cu" +#include "flash_fwd_hdim256_bf16_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..160eb3a18e4817c1aed0f422cf5bba535ba5f981 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_bf16_split_softcap_sm90.cu" +#include "flash_fwd_hdim96_bf16_split_softcap_sm90.cu" +#include "flash_fwd_hdim128_bf16_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_bf16_split_softcap_sm90.cu" +#include "flash_fwd_hdim256_bf16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..28819a690a3b4341557d994ecd4021d9be0bd81e --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_e4m3_packgqa_sm90.cu" +#include "flash_fwd_hdim96_e4m3_packgqa_sm90.cu" +#include "flash_fwd_hdim128_e4m3_packgqa_sm90.cu" +#include "flash_fwd_hdim192_e4m3_packgqa_sm90.cu" +#include "flash_fwd_hdim256_e4m3_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..933ad98271957d41f3d43f487d72a37467185c76 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_e4m3_paged_sm90.cu" +#include "flash_fwd_hdim96_e4m3_paged_sm90.cu" +#include "flash_fwd_hdim128_e4m3_paged_sm90.cu" +#include "flash_fwd_hdim192_e4m3_paged_sm90.cu" +#include "flash_fwd_hdim256_e4m3_paged_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..a934f7d9924eb7a1bfc7b5ed93a1e6c1c206e94d --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu" +#include "flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu" +#include "flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu" +#include "flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu" +#include "flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..8475e878ae2d017f97d46bb27955d1fa91ae5e58 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_hdim64_e4m3_paged_split_sm90.cu" +#include "flash_fwd_hdim96_e4m3_paged_split_sm90.cu" +#include "flash_fwd_hdim128_e4m3_paged_split_sm90.cu" +#include "flash_fwd_hdim192_e4m3_paged_split_sm90.cu" +#include "flash_fwd_hdim256_e4m3_paged_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..dd1405b17f06c5d59a4a49c662ead4e4f8b9504c --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..7e7d806c6d575c095fdb76d8fb82d27014cba1d4 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_e4m3_sm90.cu" +#include "flash_fwd_hdim96_e4m3_sm90.cu" +#include "flash_fwd_hdim128_e4m3_sm90.cu" +#include "flash_fwd_hdim192_e4m3_sm90.cu" +#include "flash_fwd_hdim256_e4m3_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..f973a4e411d039e97f8ce6c855f25045fbcc1803 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..30390838d3901f1d95db65d5adfb263f4c9a6c30 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_e4m3_softcap_sm90.cu" +#include "flash_fwd_hdim96_e4m3_softcap_sm90.cu" +#include "flash_fwd_hdim128_e4m3_softcap_sm90.cu" +#include "flash_fwd_hdim192_e4m3_softcap_sm90.cu" +#include "flash_fwd_hdim256_e4m3_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..0b629bd2b3224a52238a7814768e94a6d3d398f1 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_e4m3_split_sm90.cu" +#include "flash_fwd_hdim96_e4m3_split_sm90.cu" +#include "flash_fwd_hdim128_e4m3_split_sm90.cu" +#include "flash_fwd_hdim192_e4m3_split_sm90.cu" +#include "flash_fwd_hdim256_e4m3_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..818c7fafb7a0eac09887061adbcfe3737e777f9b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_e4m3_split_softcap_sm90.cu" +#include "flash_fwd_hdim96_e4m3_split_softcap_sm90.cu" +#include "flash_fwd_hdim128_e4m3_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_e4m3_split_softcap_sm90.cu" +#include "flash_fwd_hdim256_e4m3_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..6652824d0752d4247e407256e7c5ea4fded56c4a --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_fp16_packgqa_sm90.cu" +#include "flash_fwd_hdim96_fp16_packgqa_sm90.cu" +#include "flash_fwd_hdim128_fp16_packgqa_sm90.cu" +#include "flash_fwd_hdim192_fp16_packgqa_sm90.cu" +#include "flash_fwd_hdim256_fp16_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..05d11e2e2583660775720695fe88e21650d38d4d --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_fp16_paged_sm90.cu" +#include "flash_fwd_hdim96_fp16_paged_sm90.cu" +#include "flash_fwd_hdim128_fp16_paged_sm90.cu" +#include "flash_fwd_hdim192_fp16_paged_sm90.cu" +#include "flash_fwd_hdim256_fp16_paged_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..b638138eb2623a7dc475431063a1cdc120ab6f44 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_fp16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim96_fp16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim128_fp16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim192_fp16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim256_fp16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..3619a2175f086159c98c22e5cca4e885c7320ae4 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_hdim64_fp16_paged_split_sm90.cu" +#include "flash_fwd_hdim96_fp16_paged_split_sm90.cu" +#include "flash_fwd_hdim128_fp16_paged_split_sm90.cu" +#include "flash_fwd_hdim192_fp16_paged_split_sm90.cu" +#include "flash_fwd_hdim256_fp16_paged_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..3a408ceacbd6ffa11d6a4502d83f6f397407c1d1 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_fp16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_fp16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..eec11be9162c565066a8e8851fff349cef89d910 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_fp16_sm90.cu" +#include "flash_fwd_hdim96_fp16_sm90.cu" +#include "flash_fwd_hdim128_fp16_sm90.cu" +#include "flash_fwd_hdim192_fp16_sm90.cu" +#include "flash_fwd_hdim256_fp16_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..ca2a1e1b843c2b79d3bc15754720653f5aa94897 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..8cf31a8a85f76445d754416827c665af8da17df5 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_fp16_softcap_sm90.cu" +#include "flash_fwd_hdim96_fp16_softcap_sm90.cu" +#include "flash_fwd_hdim128_fp16_softcap_sm90.cu" +#include "flash_fwd_hdim192_fp16_softcap_sm90.cu" +#include "flash_fwd_hdim256_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..5ee7ace63ace05d2b5a1b875652a2313f93c9fb0 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_fp16_split_sm90.cu" +#include "flash_fwd_hdim96_fp16_split_sm90.cu" +#include "flash_fwd_hdim128_fp16_split_sm90.cu" +#include "flash_fwd_hdim192_fp16_split_sm90.cu" +#include "flash_fwd_hdim256_fp16_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..4da0ee704eb8d736dcb44a5dfa4b13f62e82b8e0 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_fp16_split_softcap_sm90.cu" +#include "flash_fwd_hdim96_fp16_split_softcap_sm90.cu" +#include "flash_fwd_hdim128_fp16_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_fp16_split_softcap_sm90.cu" +#include "flash_fwd_hdim256_fp16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..ddd8bf07c4a601c34d04b7c0175093d6b755defd --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_bf16_packgqa_sm90.cu" +#include "flash_fwd_hdim64_512_bf16_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..c9494c4f1d23fecdda5f219b432d0ebae9db9112 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_bf16_paged_sm90.cu" +#include "flash_fwd_hdim64_512_bf16_paged_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..4b2ec583cfdb4d89899945e2f88e33b5c010b096 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..306722d45865b08d7f59a1e9655822764fb2eb01 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_bf16_paged_split_sm90.cu" +#include "flash_fwd_hdim64_512_bf16_paged_split_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..e44b2d2465463f8aac1ed28631f27b94eb38bee6 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..d52417daef3c6d8de6af777322e308ec104babf9 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_bf16_sm90.cu" +#include "flash_fwd_hdim64_512_bf16_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..6428c461aa9c485bcb9856b3f8dfe2a41a5a3522 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..d0df6306e28f059c35350be22b9e71c212b4b94b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_bf16_softcap_sm90.cu" +#include "flash_fwd_hdim64_512_bf16_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..e116d3ea7c7f98b778f27b7465b8976c664c4282 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_bf16_split_sm90.cu" +#include "flash_fwd_hdim64_512_bf16_split_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..bededf4a7d8f93d4503dcfac297c079d3d75880f --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu" +#include "flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..526a51fb71e0d98ff104187027313cd24fdfec4c --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..4e5d9cc4fe2f2cc873f789d4bdf47b12a7aa808d --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_paged_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..f553af139f22e59c1dfc6abdee828d9c8e2afe00 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..aa2a8260d2500c0db89600641b2948f81e68219c --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..bbc4449ba21cedad486767b2d985769e56193924 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..02ca85ad672277db5e85b827effb8a4737d08201 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..d090fde972b919a50afd37b109cdc66e11226093 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..d48f60ad7e2fc8b133ed5bdb629d3aba93d1ae37 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. 
+// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..9dda19d1ceaf8fa43aaca9e87a1051509c15ee05 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..f3e51fc9ebdf1f37eb329a5aa39b0d7e326ac0e9 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..ea531027938b8f74f86a2f1a0c9fc8195b99e0d3 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_fp16_packgqa_sm90.cu" +#include "flash_fwd_hdim64_512_fp16_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..10d86e5e99c4cd379f94fba6127355a1b5204c04 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_fp16_paged_sm90.cu" +#include "flash_fwd_hdim64_512_fp16_paged_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..375197ef75e54d468d8fed3524abaf50cd9fb075 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..4fc4831cf58d6f188cb21a9b3ca9824d276d49a5 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_fp16_paged_split_sm90.cu" +#include "flash_fwd_hdim64_512_fp16_paged_split_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..a3d94a163a9a2158c4b9ddf1c468c428f4847198 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..9663103ae117876652399c211bbb778a4d048069 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_fp16_sm90.cu" +#include "flash_fwd_hdim64_512_fp16_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..b7d2b07ca840dfd9768811edec1848838aafac8b --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..471b5abaafc555ccdf8eeceb706e649358786db1 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_fp16_softcap_sm90.cu" +#include "flash_fwd_hdim64_512_fp16_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..10f72182fa97128ff83db68d029ae4a2cbdce2d5 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_fp16_split_sm90.cu" +#include "flash_fwd_hdim64_512_fp16_split_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_split_sm90.cu" \ No newline at end of file diff --git a/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..54db60c23b162904546c89e2f249deab2c59e942 --- /dev/null +++ b/flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu" +#include "flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/flash-attn/mainloop_bwd_sm80.hpp b/flash-attn/mainloop_bwd_sm80.hpp new file mode 100644 index 0000000000000000000000000000000000000000..23baae61731cdddf816e8089f3ad454cb49ca80c --- /dev/null +++ b/flash-attn/mainloop_bwd_sm80.hpp @@ -0,0 +1,915 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include + +#include "cute/tensor.hpp" + +#include "seqlen.h" +#include "mask.h" +#include "mask.h" +#include "softmax.h" +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +struct CollectiveMainloopBwdSm80 { + + static constexpr int kStages = Stages; + static constexpr int kStages_dO = Stages_dO; + static_assert(kStages >= kStages_dO); + using TileShape_MNK = TileShape_MNK_; + using Element = Element_; + using ElementAccum = ElementAccum_; + using ArchTag = ArchTag_; + static constexpr bool Is_causal = Is_causal_; + static constexpr bool Is_local = Is_local_; + static constexpr bool Has_softcap = Has_softcap_; + static constexpr bool Varlen = Varlen_; + static constexpr int NumMmaWarps = NumMmaWarpGroups * cutlass::NumWarpsPerWarpGroup; + + static constexpr bool SdP_swapAB = SdP_swapAB_; + static constexpr bool dKV_swapAB = dKV_swapAB_; + static constexpr bool dQ_swapAB = dQ_swapAB_; + + static constexpr bool Q_dO_same_stages = kStages == kStages_dO; + + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + + using SeqlenInfo_t = flash::SeqlenInfoQK; + using BlockMN_t = flash::BlockMN; + + static_assert(ArchTag::kMinComputeCapability >= 80); + + static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80; + + static constexpr int NumMmaThreads = NumMmaWarps * cutlass::NumThreadsPerWarp; + static constexpr int NumProducerThreads = NumMmaThreads; // For compatibility with TileScheduler + + using MMA_Atom_Arch = std::conditional_t< + ArchTag::kMinComputeCapability >= 80, + std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >, + MMA_Atom + >; + + static_assert(NumMmaWarps % AtomLayoutMSdP == 0); + static_assert(NumMmaWarps % AtomLayoutNdKV == 0); + static_assert(NumMmaWarps % AtomLayoutMdQ == 0); + static constexpr bool Mma_dKV_is_RS = AtomLayoutMSdP == 1 && AtomLayoutNdKV == NumMmaWarps && SdP_swapAB && !dKV_swapAB; + static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == NumMmaWarps && AtomLayoutMdQ == NumMmaWarps && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS + + using AtomLayoutSdP = std::conditional_t< + !SdP_swapAB, + Layout, Int, _1>>, + Layout, Int, _1>> + >; + static constexpr bool MmaSdPEvenN = ((!SdP_swapAB ? kBlockN : kBlockM) / size<1>(AtomLayoutSdP{})) % 16 == 0; + using TiledMmaSdP = TiledMMA< + MMA_Atom_Arch, + AtomLayoutSdP, + Tile(AtomLayoutSdP{}))>, Int<(MmaSdPEvenN ? 16 : 8) * CUTE_STATIC_V(size<1>(AtomLayoutSdP{}))>, _16>>; + + using AtomLayoutdKV = std::conditional_t< + !dKV_swapAB, + Layout, Int, _1>>, + Layout, Int, _1>> + >; + static constexpr bool MmadKVEvenN = ((!dKV_swapAB ? 
kHeadDim : kBlockN) / size<1>(AtomLayoutdKV{})) % 16 == 0; + using TiledMmadKV = TiledMMA< + MMA_Atom_Arch, + AtomLayoutdKV, + Tile(AtomLayoutdKV{}))>, Int<(MmadKVEvenN ? 16 : 8) * CUTE_STATIC_V(size<1>(AtomLayoutdKV{}))>, _16>>; + + using AtomLayoutdQ = std::conditional_t< + !dQ_swapAB, + Layout, Int, _1>>, + Layout, Int, _1>> + >; + static constexpr bool MmadQEvenN = ((!dQ_swapAB ? kHeadDim : kBlockM) / size<1>(AtomLayoutdQ{})) % 16 == 0; + using TiledMmadQ = TiledMMA< + MMA_Atom_Arch, + AtomLayoutdQ, + Tile(AtomLayoutdQ{}))>, Int<(MmadQEvenN ? 16 : 8) * CUTE_STATIC_V(size<1>(AtomLayoutdQ{}))>, _16>>; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each + // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. + static constexpr int kBytePerRow = kHeadDim * sizeof(Element); + static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); + + static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1)); + static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4); + + // We need to accommodate both Q and Q^T (and dO and dO^T) in shared memory. + // Q & dO are used in the SdP Mma and Q^T and dO^T are used in the dKV Mma. + // Since this is GMMA::Major::K, the M dimension (kBlockM) doesn't matter for the layout, only the K dimension + // changes the layout. + using SmemLayoutAtomQdO = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = + decltype(tile_to_shape(SmemLayoutAtomQdO{}, + make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + using SmemLayoutdO = + decltype(tile_to_shape(SmemLayoutAtomQdO{}, + make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + + using SmemLayoutAtomKV = decltype( + composition(Swizzle{}, + // TODO: FA2 has a slightly different layout, does it matter? + Layout>, + Stride, _1>>{})); + using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomKV{}, select<1, 2>(TileShape_MNK{}))); + + using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomKV{}, select<1, 2>(TileShape_MNK{}))); + + // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. + static constexpr int kPBlockN = kBlockN % 64 == 0 ? 64 : (kBlockN % 32 == 0 ? 32 : 16); + static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); + // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); + static constexpr int kSwizzlePdS = 3; + using SmemLayoutAtomPdS = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutPdS = decltype(tile_to_shape( + SmemLayoutAtomPdS{}, + make_shape(Int{}, Int{}))); + + // We set stride to be multiple of 64 so that if ShuffleLSE, even if threads read from sLSE but out of bounds, + // it's still a valid smem address. + using SmemLayoutLSE = cute::Layout, Int>, cute::Stride<_1, Int>>; + using SmemLayoutLSEMma = std::conditional_t< + SdP_swapAB, + cute::Layout, Int, Int>, cute::Stride<_0, _1, Int>>, + cute::Layout, Int, Int>, cute::Stride<_1, _0, Int>> + >; + + // Note this is the transpose in terms of the view, not in terms of memory. 
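The "transpose in terms of the view, not in terms of memory" note above is easy to make concrete: the *t layouts that follow reuse the exact same shared-memory bytes and only swap which mode is contiguous in the index map. A minimal standalone sketch of that idea in plain C++ (no CuTe; Layout2D, kBlockM, kHeadDim here are illustrative names, not the kernel's types):

#include <cassert>
#include <cstdio>

// A "layout" is just an index map (shape + strides). Transposing the view
// means swapping the (extent, stride) pairs; the buffer itself is untouched.
struct Layout2D {
    int shape0, shape1;      // extents
    long stride0, stride1;   // element strides
    long operator()(int i, int j) const { return i * stride0 + j * stride1; }
};

int main() {
    constexpr int kBlockM = 4, kHeadDim = 8;
    float smem_q[kBlockM * kHeadDim];
    for (int i = 0; i < kBlockM * kHeadDim; ++i) smem_q[i] = float(i);

    // K-major (M, K) view: stride kHeadDim along M, stride 1 along K.
    Layout2D layout_q  {kBlockM, kHeadDim, kHeadDim, 1};
    // Transposed (K, M) view over the SAME memory: just swap the modes.
    Layout2D layout_qt {kHeadDim, kBlockM, 1, kHeadDim};

    for (int m = 0; m < kBlockM; ++m)
        for (int k = 0; k < kHeadDim; ++k)
            assert(smem_q[layout_q(m, k)] == smem_q[layout_qt(k, m)]);
    std::printf("Q and Q^T read the same bytes through different index maps\n");
    return 0;
}

The cute::composition calls below do essentially the same thing, with the additional property that the swizzle baked into the original layout is carried through the composed index map.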
+ using SmemLayoutQt = + decltype(cute::composition(SmemLayoutQ{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), + make_stride(Int{}, _1{}, Int{})))); + using SmemLayoutdOt = + decltype(cute::composition(SmemLayoutdO{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), + make_stride(Int{}, _1{}, Int{})))); + using SmemLayoutKt = + decltype(cute::composition(SmemLayoutK{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), + make_stride(Int{}, _1{})))); + using SmemLayoutPdSt = + decltype(cute::composition(SmemLayoutPdS{}, + make_layout(make_shape(Int{}, Int{}), + make_stride(Int{}, _1{})))); + + // Thread layout, 256 or 384 threads per row + using R2SLayoutAtomdQaccum = Layout>>; + using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2SLayoutAtomdQaccum{}, + Layout>{})); // Val layout, 1 vals per store + + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; + // For the case where the N dimension of MmaSdP is divisible by 8 but not by 16 + using SmemCopyAtomHalf = Copy_Atom; + // For the case where the N dimension of MmadQ is divisible by 8 but not by 16 + using SmemCopyAtomTransposedHalf = Copy_Atom; + // If !SdP_swapAB, the accum registers hold P / dS, otherwise they hold Pt / dSt. + // If PdS_major is MN, then we need to "transpose" the write. + // TODO: check this write + using R2SCopyAtomPdS = Copy_Atom, Element>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using GmemCopyStruct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL_ZFILL, + AutoVectorizingCopyWithAssumedAlignment<128> + >; + using GmemCopyAtom = Copy_Atom; + + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; + static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(GmemCopyAtom{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per read + using GmemCopyAtomLSE = Copy_Atom; + using GmemLayoutAtomLSE = Layout>>; + using GmemTiledCopyLSE = decltype(make_tiled_copy(GmemCopyAtomLSE{}, GmemLayoutAtomLSE{}, + Layout>{})); // Val layout, 4 vals per store + // So that we don't have to check if we overshot kBlockM when we load Q + // static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0); + + using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) + using StrideQKV = cute::Stride; + using ShapeLSE = cute::Shape; // (seqlen, head, batch) + using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch) + using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch) + using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; + + // These are tuned for speed. They don't affect correctness. + // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64 + // this helps quite a bit to not have to do causal masking for most of the iterations. + // For hdim 192, separating masking iterations results in register spills. 
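To see why the separate masking iterations mentioned above pay off, it helps to count how many Q-row blocks actually touch the causal boundary for a single K/V block. A small worked example in plain C++; the concrete sizes (seqlen 1024, kBlockM 64, kBlockN 128, n_block 3) are illustrative assumptions, and the m_block_masking_max expression mirrors the one used further down in this file:

#include <cassert>

int main() {
    // Illustrative sizes, not taken from this kernel's tile configurations.
    const int seqlen_q = 1024, seqlen_k = 1024;
    const int kBlockM = 64, kBlockN = 128;
    const int n_block = 3;             // K/V columns [384, 512)
    const int window_size_right = 0;   // pure causal

    // Causal: row i attends to column j iff j <= i + (seqlen_k - seqlen_q).
    // Rows entirely below this block's first column never reach it, so for these
    // sizes the backward sweep starts at m_block_min; rows past the block's last
    // column see the whole block unmasked.
    const int m_block_min = (n_block * kBlockN) / kBlockM;                  // = 6
    const int m_block_masking_max =
        ((n_block + 1) * kBlockN - 1 + seqlen_q - seqlen_k - window_size_right) / kBlockM + 1;  // = 8

    assert(m_block_min == 6 && m_block_masking_max == 8);
    // Only m_blocks 6 and 7 (rows 384..511) intersect the diagonal and need the
    // masked path; every later m_block up to m_block_max can skip masking entirely.
    return 0;
}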
+ // static constexpr bool SeparateMaskingIterations = kHeadDim <= 64; + static constexpr bool SeparateMaskingIterations = false; + // Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share them and then + // shuffle to get the value whenever we need? This can reduce register pressure when SdP_swapAB, where each + // thread needs to keep statistics for (kBlockM / 4) rows. If !SdP_swapAB, each thread only needs to keep + // statistic for 2 rows. + // static constexpr bool ShuffleLSE = SdP_swapAB && kHeadDim <= 64; + // static constexpr bool ShuffledPsum = SdP_swapAB && kHeadDim <= 64; + static constexpr bool ShuffleLSE = SdP_swapAB && false; + static constexpr bool ShuffledPsum = SdP_swapAB && false; + + static constexpr bool Share_QV_Smem = V_in_regs; + using SmemP_t = std::conditional_t, cute::array_aligned>>; + + struct TensorStorageSharedQV : cute::aligned_struct<128> { + cute::array_aligned> smem_k; + union { + cute::array_aligned> smem_v; + cute::array_aligned> smem_q; + }; + cute::array_aligned> smem_do; + cute::array_aligned, 128> smem_lse; + cute::array_aligned, 128> smem_dpsum; + SmemP_t smem_p; + cute::array_aligned> smem_ds; + }; + + struct TensorStorageSeparateQV : cute::aligned_struct<128> { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + cute::array_aligned> smem_q; + cute::array_aligned> smem_do; + cute::array_aligned, 128> smem_lse; + cute::array_aligned, 128> smem_dpsum; + SmemP_t smem_p; + cute::array_aligned> smem_ds; + }; + + using TensorStorage = std::conditional_t; + + // Host side kernel arguments + struct Arguments { + Element const* const ptr_Q; + ShapeQKV const shape_Q; + StrideQKV const stride_Q; + Element const* const ptr_K; + ShapeQKV const shape_K; + StrideQKV const stride_K; + Element const* const ptr_V; + ShapeQKV const shape_V; + StrideQKV const stride_V; + Element const* const ptr_dO; + ShapeQKV const shape_dO; + StrideQKV const stride_dO; + ElementAccum* const ptr_dQaccum; + ShapedQaccum const shape_dQaccum; + StridedQaccum const stride_dQaccum; + float const* const ptr_LSE_log2; + ShapeLSE const shape_LSE; + StrideLSE const stride_LSE_log2; + float const* const ptr_dPsum; + StrideLSE const stride_dPsum; + float const softmax_scale; + int const window_size_left, window_size_right, attention_chunk; + float const softcap_val; + int const num_batch; + int* const dq_semaphore; + int const* const cu_seqlens_q = nullptr; + int const* const cu_seqlens_k = nullptr; + int const* const seqused_q = nullptr; + int const* const seqused_k = nullptr; + }; + + // Device side kernel params + struct Params { + Element const* const ptr_Q; + ShapeQKV const shape_Q; + StrideQKV const stride_Q; + Element const* const ptr_K; + ShapeQKV const shape_K; + StrideQKV const stride_K; + Element const* const ptr_V; + ShapeQKV const shape_V; + StrideQKV const stride_V; + Element const* const ptr_dO; + ShapeQKV const shape_dO; + StrideQKV const stride_dO; + ElementAccum* const ptr_dQaccum; + ShapedQaccum const shape_dQaccum; + StridedQaccum stride_dQaccum; + cutlass::FastDivmod qhead_per_khead_divmod; + float const* const ptr_LSE_log2; + ShapeLSE const shape_LSE; + StrideLSE const stride_LSE_log2; + float const* const ptr_dPsum; + StrideLSE const stride_dPsum; + float const softmax_scale, softmax_scale_log2; + int const window_size_left, window_size_right; + cutlass::FastDivmod attention_chunk_divmod; + float const softcap_val; + int const num_batch; + int *const dq_semaphore; + int const *const cu_seqlens_q = nullptr; + int const *const 
cu_seqlens_k = nullptr; + int const *const seqused_q = nullptr; + int const *const seqused_k = nullptr; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); } + // Avoid dividing by zero + cutlass::FastDivmod attention_chunk_divmod(args.attention_chunk >= 1 ? args.attention_chunk : 1); + attention_chunk_divmod.divisor = args.attention_chunk; + // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. + // Right after this, we multiply by log2(e) before applying exp2. + // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val + // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) + // (assigning it to params.softmax_scale_log2). + // In the backward, we need to multiply by + // (1 - tanh^2) * softmax_scale / softcap_val * softcap_val = (1 - tanh^2) * softmax_scale. + // Instead we multiply by (1 - tanh^2) and multiply dK and dV by params.softmax_scale + // (the original softmax_scale) at the end. + return {args.ptr_Q, args.shape_Q, args.stride_Q, + args.ptr_K, args.shape_K, args.stride_K, + args.ptr_V, args.shape_V, args.stride_V, + args.ptr_dO, args.shape_dO, args.stride_dO, + args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum, + cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), + args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, + args.softmax_scale, + !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), + args.window_size_left, args.window_size_right, attention_chunk_divmod, + !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, + args.num_batch, args.dq_semaphore, + args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k}; + } + + template + CUTLASS_DEVICE bool + mma(Params const& params, + FrgTensordKV& tdKrdK, + FrgTensordKV& tdVrdV, + int thread_idx, + cute::tuple block_coord, + SharedStorage& shared_storage + ) { + static_assert(is_rmem::value, "dK and dV tensor must be rmem resident."); + + int n_block = get<0>(block_coord); + int bidh = get<1>(block_coord); + int bidb = get<2>(block_coord); + SeqlenInfo_t seqlen_info{ + bidb, get<0>(params.shape_Q), size<0>(params.shape_K), + params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k + }; + auto m_block_min_max = BlockMN_t::get_m_block_min_max( + seqlen_info, n_block, bidb, + params.window_size_left, params.window_size_right, 0 /*sink_token_length*/); + int const m_block_min = get<0>(m_block_min_max); + int const m_block_max = get<1>(m_block_min_max); + // It's possible to have m_block_max <= m_block_min. 
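The softcap remark in to_underlying_arguments above packs two constants so the inner loop needs only one tanh, one multiply-add and one exp2 per element. A standalone numeric check of that algebra in plain C++ (pre_scale, pre_log2 and the sample values are illustrative stand-ins for what ends up in params.softcap_val and params.softmax_scale_log2):

#include <cassert>
#include <cmath>
#include <cstdio>
#include <initializer_list>

int main() {
    // Illustrative values; the kernel receives these through its host-side arguments.
    const float softmax_scale = 0.125f, softcap = 30.0f;
    const float log2e = 1.4426950408889634f;
    const float pre_scale = softmax_scale / softcap;  // stored as params.softcap_val
    const float pre_log2  = softcap * log2e;          // stored as params.softmax_scale_log2
    const float lse_log2  = 3.0f;                     // row LSE, already in log2 units

    for (float s : {-40.0f, -1.0f, 0.0f, 2.5f, 80.0f}) {
        // Reference path: softcap the raw score, then exponentiate in base 2.
        const float capped = softcap * std::tanh(s * softmax_scale / softcap);
        const float p_ref  = std::exp2(capped * log2e - lse_log2);
        // Fused path used in the mainloop: same result with pre-multiplied constants.
        const float p_fused = std::exp2(std::tanh(s * pre_scale) * pre_log2 - lse_log2);
        assert(std::fabs(p_ref - p_fused) <= 1e-6f * std::fmax(1.0f, p_ref));
        // Backward: d(capped)/ds = (1 - tanh^2) * softmax_scale. The kernel applies
        // the (1 - tanh^2) factor per element (dtanh) and folds softmax_scale into
        // the final scaling of dK and dV, as the comment above describes.
    }
    std::printf("pre-multiplied softcap path matches the reference\n");
    return 0;
}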
Exit early + if constexpr (Is_causal || Is_local || Varlen) { + if (m_block_max <= m_block_min) { return false; } + } + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); + Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); + Tensor sQt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQt{}); + Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdOt{}); + Tensor sKt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutKt{}); + Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdS{}); + Tensor sPt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdSt{}); + Tensor sdS = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdS{}); + Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdSt{}); + Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSE{}); + Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSE{}); + Tensor sLSEMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSEMma{}); + Tensor sdPsumMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSEMma{}); + + bool const is_varlen_q = Varlen && params.cu_seqlens_q; + bool const is_varlen_k = Varlen && params.cu_seqlens_k; + int bidh_kv = params.qhead_per_khead_divmod.divide(bidh); + Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q), params.shape_Q, params.stride_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_dO, params.stride_dO)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); + Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_V, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); + Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, bidh, !is_varlen_q ? bidb : 0); + Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, bidh, !is_varlen_q ? bidb : 0); + Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), + params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen_q ? 
bidb : 0); + + Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) + Tensor gdO = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) + Tensor gK = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) + Tensor gV = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) + Tensor gLSE = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mLSE), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _) + Tensor gdPsum = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mdPsum), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _) + Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(_)); // (M * K, _) + + GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(thread_idx); + auto gmem_thr0_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(_0{}); // For index calculation + GmemTiledCopyLSE gmem_tiled_copy_lse; + auto gmem_thr_copy_lse = gmem_tiled_copy_lse.get_thread_slice(thread_idx); + R2STiledCopydQaccum r2s_tiled_copy_dQaccum; + auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(thread_idx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tdOgdO = gmem_thr_copy_QKV.partition_S(gdO); + Tensor tdOsdO = gmem_thr_copy_QKV.partition_D(sdO); + Tensor tLSEgLSE = gmem_thr_copy_lse.partition_S(gLSE); + Tensor tLSEsLSE = gmem_thr_copy_lse.partition_D(sLSE); + Tensor tLSEgdPsum = gmem_thr_copy_lse.partition_S(gdPsum); + Tensor tLSEsdPsum = gmem_thr_copy_lse.partition_D(sdPsum); + // We can reuse r2s_thr_copy_dQaccum for this partitioning + Tensor tdQgdQaccum = r2s_thr_copy_dQaccum.partition_D(gdQaccum); + // if (blockIdx.x == 0 && threadIdx.x == 128) { print(mdQaccum); printf("\n"); print(gdQaccum_); printf("\n"); print(gdQaccum); printf("\n"); print(tdQgdQaccum); printf("\n"); } + + TiledMmaSdP tiled_mma_SdP; + TiledMmadKV tiled_mma_dKV; + TiledMmadQ tiled_mma_dQ; + + auto thr_mma_SdP = tiled_mma_SdP.get_thread_slice(thread_idx); + auto thr_mma_dKV = tiled_mma_dKV.get_thread_slice(thread_idx); + auto thr_mma_dQ = tiled_mma_dQ.get_thread_slice(thread_idx); + + // Allocate "fragments/descriptors" + // We have to use the templated mma_partition_fragment_AB instead of cute::conditional_return or lambda, + // because some partition_fragment_A/B don't compile. 
+ // https://stackoverflow.com/questions/50051473/if-constexpr-in-c17-does-not-work-in-a-non-templated-function + Tensor tdPrV = mma_partition_fragment_AB(thr_mma_SdP, sV); + + // Copy Atom retiling + auto smem_copy_atom_SdP_B = cute::conditional_return(SmemCopyAtom{}, SmemCopyAtomHalf{}); + auto smem_tiled_copy_QdO = cute::conditional_return(make_tiled_copy_A(SmemCopyAtom{}, tiled_mma_SdP), make_tiled_copy_B(smem_copy_atom_SdP_B, tiled_mma_SdP)); + auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(thread_idx); + Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ); + Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO); + + auto smem_tiled_copy_KV = cute::conditional_return(make_tiled_copy_B(smem_copy_atom_SdP_B, tiled_mma_SdP), make_tiled_copy_A(SmemCopyAtom{}, tiled_mma_SdP)); + auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(thread_idx); + Tensor tSsK = smem_thr_copy_KV.partition_S(sK); + Tensor tdPsV = smem_thr_copy_KV.partition_S(sV); + + auto r2s_tiled_copy_PdS = make_tiled_copy_C(R2SCopyAtomPdS{}, tiled_mma_SdP); + auto r2s_thr_copy_PdS = r2s_tiled_copy_PdS.get_thread_slice(thread_idx); + Tensor tPsP = r2s_thr_copy_PdS.partition_D(cute::conditional_return(sP, sPt)); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor tdSsdS = r2s_thr_copy_PdS.partition_D(cute::conditional_return(sdS, sdSt)); // ((Atom,AtomNum),PIPE_M,PIPE_N) + // if (blockIdx.x == 0 && threadIdx.x == 128) { print(r2s_thr_copy_PdS); print(sP); printf("\n"); print(sPt); printf("\n"); print(tPsP); printf("\n"); print(tdSsdS); printf("\n"); } + + auto smem_copy_atom_dKV_B = cute::conditional_return(SmemCopyAtomTransposed{}, SmemCopyAtomTransposedHalf{}); + auto smem_tiled_copy_PdSt = cute::conditional_return(make_tiled_copy_A(SmemCopyAtomTransposed{}, tiled_mma_dKV), make_tiled_copy_B(smem_copy_atom_dKV_B, tiled_mma_dKV)); + auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(thread_idx); + Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt); + Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt); + + auto smem_tiled_copy_QdOt = cute::conditional_return(make_tiled_copy_B(smem_copy_atom_dKV_B, tiled_mma_dKV), make_tiled_copy_A(SmemCopyAtomTransposed{}, tiled_mma_dKV)); + auto smem_thr_copy_QdOt = smem_tiled_copy_QdOt.get_thread_slice(thread_idx); + Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt); + Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt); + + auto smem_tiled_copy_dS = cute::conditional_return( + make_tiled_copy_A(SmemCopyAtom{}, tiled_mma_dQ), + make_tiled_copy_B(cute::conditional_return(SmemCopyAtom{}, SmemCopyAtomHalf{}), tiled_mma_dQ)); + auto smem_thr_copy_dS = smem_tiled_copy_dS.get_thread_slice(thread_idx); + Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS); + + auto smem_tiled_copy_Kt = cute::conditional_return( + make_tiled_copy_B(cute::conditional_return(SmemCopyAtomTransposed{}, SmemCopyAtomTransposedHalf{}), tiled_mma_dQ), + make_tiled_copy_A(SmemCopyAtomTransposed{}, tiled_mma_dQ)); + auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(thread_idx); + Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt); + + // thr_mma_SdP.partition_C(sLSEMma) has shape (MMA=4, MMA_M, MMA_N, PIPE), we only take the col indices + // or row indices, depending on whether SdP_swapAB. 
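The remark above about needing the templated mma_partition_fragment_AB helper (rather than cute::conditional_return or a lambda) comes down to a C++17 rule: only inside a template is the discarded branch of `if constexpr` left uninstantiated. A minimal mock independent of CuTe (MockThrMma and partition_fragment_AB here are illustrative, not the kernel's definitions):

// C++17. The helper is a function template, so for SwapAB == false the
// else-branch is discarded without ever being instantiated, even though
// MockThrMma has no partition_fragment_B at all.
template <bool SwapAB, class ThrMma, class Tensor>
auto partition_fragment_AB(ThrMma& thr_mma, Tensor const& t) {
    if constexpr (!SwapAB) {
        return thr_mma.partition_fragment_A(t);
    } else {
        return thr_mma.partition_fragment_B(t);
    }
}

struct MockThrMma {
    int partition_fragment_A(int t) { return t + 1; }
    // Intentionally no partition_fragment_B.
};

int main() {
    MockThrMma m;
    // Compiles and runs. Writing the same branch directly in a non-template
    // function, or passing both expressions to a runtime/constexpr select,
    // would require both member calls to be well-formed, which is exactly the
    // failure mode the comment above works around.
    return partition_fragment_AB<false>(m, 41) == 42 ? 0 : 1;
}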
+ Tensor tSsLSEMma = logical_divide(thr_mma_SdP.partition_C(sLSEMma), Shape<_2>{}); // (2, 2, MMA_M, MMA_N, PIPE) + Tensor tSsLSE = group_modes<0, 2>(cute::conditional_return( + tSsLSEMma(make_coord(_0{}, _), _, _0{}, _), // (2, MMA_M, PIPE) + tSsLSEMma(make_coord(_, _0{}), _0{}, _, _))); // (2, MMA_N, PIPE) + Tensor tSsdPsumMma = logical_divide(thr_mma_SdP.partition_C(sdPsumMma), Shape<_2>{}); + Tensor tSsdPsum = group_modes<0, 2>(cute::conditional_return( + tSsdPsumMma(make_coord(_0{}, _), _, _0{}, _), // (2, MMA_M, PIPE) + tSsdPsumMma(make_coord(_, _0{}), _0{}, _, _))); // (2, MMA_N, PIPE) + // if (blockIdx.x == 0 && threadIdx.x == 128) { print(sLSEMma); printf("\n"); print(tLSEsLSE); printf("\n"); } + // If we want to split the stats among the 8 threads that share the same rows. + static constexpr int kStatsPerThread = cute::ceil_div(decltype(size(tSsLSE))::value, 8); + + // Predicates + Tensor cQ = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); + Tensor t0QcQ = gmem_thr0_copy_QKV.partition_S(cQ); + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_Q); } + Tensor cLSE = cute::make_identity_tensor(select<0>(TileShape_MNK{})); + Tensor tLSEcLSE = gmem_thr_copy_lse.partition_S(cLSE); + Tensor tdOpdO = make_tensor(make_shape(size<2>(tdOsdO))); + #pragma unroll + for (int k = 0; k < size(tdOpdO); ++k) { tdOpdO(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_dO); } + + int const seqlen_q = seqlen_info.seqlen_q; + int const seqlen_k = seqlen_info.seqlen_k; + + flash::Mask mask( + thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, + params.attention_chunk_divmod, params.qhead_per_khead_divmod + ); + + { + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + // Predicates + Tensor cKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); + Tensor t0KVcKV = gmem_thr0_copy_QKV.partition_S(cKV); + Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); + Tensor tVpV = make_tensor(make_shape(size<2>(tVsV))); + #pragma unroll + for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_K); } + #pragma unroll + for (int k = 0; k < size(tVpV); ++k) { tVpV(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_V); } + // Do we need bound check to make sure the row doesn't go above kBlockN + static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; + // static_assert(EvenN); // It simplifies the loading of K and V + // Instead of passing in tKVcKV, we pass in t0KVcKV and subtract the offset from the limit + // (seqlen_k - n_block * kBlockN). This is because the entries of t0KVcKV are known at compile time. + // int const seqlenk_row_limit = -int(get<0>(tKVcKV(_0{}, _0{}, _0{}))) + (EvenN + // ? 
seqlen_info.seqlen_k - n_block * kBlockN + // : std::min(seqlen_info.seqlen_k - n_block * kBlockN, kBlockN)); + // // Need Clear_OOB_MN to be true here since the gemm will sum over the kBlockN dimension + // flash::copy( + // gmem_tiled_copy_QKV, tVgV, tVsV, t0KVcKV, tKVpKV, seqlenk_row_limit); + int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tKVcKV(_0{}, _0{}, _0{})); + #pragma unroll + for (int m = 0; m < size<1>(tVsV); ++m) { + // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to be checked + if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tKVcKV(_0{}, m, _0{})) < kBlockN) { + bool const predicate_n = get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit; + #pragma unroll + for (int k = 0; k < size<2>(tVsV); ++k) { + cute::copy(gmem_tiled_copy_QKV.with(tVpV(k) && predicate_n), tVgV(_, m, k), tVsV(_, m, k)); + } + } + } + if constexpr (V_in_regs) { flash::cp_async_fence(); } + // flash::copy( + // gmem_tiled_copy_QKV, tKgK, tKsK, t0KVcKV, tKVpKV, seqlenk_row_limit); + #pragma unroll + for (int m = 0; m < size<1>(tKsK); ++m) { + if (EvenN || m < size<1>(tKsK) - 1 || get<0>(tKVcKV(_0{}, m, _0{})) < kBlockN) { + bool const predicate_n = get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit; + #pragma unroll + for (int k = 0; k < size<2>(tKsK); ++k) { + cute::copy(gmem_tiled_copy_QKV.with(tKpK(k) && predicate_n), tKgK(_, m, k), tKsK(_, m, k)); + } + } + } + flash::cp_async_fence(); + } + + if constexpr (V_in_regs) { + flash::cp_async_wait<1>(); + __syncthreads(); + Tensor tdPrV_copy_view = smem_thr_copy_KV.retile_D(tdPrV); + Tensor tdPsV_copy_view = smem_thr_copy_KV.partition_S(sV); + cute::copy(smem_tiled_copy_KV, tdPsV_copy_view, tdPrV_copy_view); + __syncthreads(); // Sync to avoid loading Q to smem_q, which overlaps with smem_v + } + + // Do we need bound check to make sure the row doesn't go above kBlockM + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr bool EvenM = kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; + + auto load_Q_LSE = [&] (int const m_block, int const smem_pipe_write) { + // if (cute::thread0()) { printf("Inside load_Q_LSE, m_block = %d, smem_pipe_write = %d\n", m_block, smem_pipe_write); } + Tensor tQsQ_cur = tQsQ(_, _, _, smem_pipe_write); + Tensor tQgQ_cur = tQgQ(_, _, _, m_block); + // Instead of passing in tQcQ, we pass in t0QcQ and subtract the offset from the limit + // (seqlen_q - m_block * kBlockM). This is because the entries of t0QcQ are known at compile time. + // int const seqlenq_row_limit = -int(get<0>(tQcQ(_0{}, _0{}, _0{}))) + (EvenM + // ? 
seqlen_info.seqlen_q - m_block * kBlockM + // : std::min(seqlen_info.seqlen_q - m_block * kBlockM, kBlockM)); + // Need Clear_OOB_MN to be true here since the gemm will sum over the kBlockM dimension + // flash::copy( + // gmem_tiled_copy_QKV, tQgQ(_, _, _, m_block), tQsQ_cur, t0QcQ, tQpQ, seqlenq_row_limit); + int const seqlenq_row_limit = seqlen_info.seqlen_q - m_block * kBlockM - get<0>(tQcQ(_0{}, _0{}, _0{})); + #pragma unroll + for (int m = 0; m < size<1>(tQsQ); ++m) { + // If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked + if (EvenM || m < size<1>(tQsQ) - 1 || get<0>(tQcQ(_0{}, m, _0{})) < kBlockM) { + bool const predicate_m = get<0>(t0QcQ(_0{}, m, _0{})) < seqlenq_row_limit; + #pragma unroll + for (int k = 0; k < size<2>(tQsQ); ++k) { + cute::copy(gmem_tiled_copy_QKV.with(tQpQ(k) && predicate_m), tQgQ_cur(_, m, k), tQsQ_cur(_, m, k)); + } + } + } + Tensor tLSEgLSE_cur = tLSEgLSE(_, _, m_block); + Tensor tLSEsLSE_cur = tLSEsLSE(_, _, smem_pipe_write); + // We made sure LSE length is padded so we read `kBlockM` elements so that all + // elements in sLSE are filled. Without this we might have uninitialized sLSE values. + #pragma unroll + for (int m = 0; m < size<1>(tLSEsLSE); ++m) { + if (get<0>(tLSEcLSE(_0{}, m)) < kBlockM) { + cute::copy(gmem_tiled_copy_lse, tLSEgLSE_cur(_, m), tLSEsLSE_cur(_, m)); + } + } + }; + + auto load_dO_dPsum = [&] (int const m_block, int const smem_pipe_write) { + // if (cute::thread0()) { printf("Inside load_dO_dPsum, m_block = %d, smem_pipe_write = %d\n", m_block, smem_pipe_write); } + Tensor tdOsdO_cur = tdOsdO(_, _, _, smem_pipe_write); + Tensor tdOgdO_cur = tdOgdO(_, _, _, m_block); + // int const seqlenq_row_limit = -int(get<0>(tQcQ(_0{}, _0{}, _0{}))) + (EvenM + // ? seqlen_info.seqlen_q - m_block * kBlockM + // : std::min(seqlen_info.seqlen_q - m_block * kBlockM, kBlockM)); + // flash::copy( + // gmem_tiled_copy_QKV, tdOgdO(_, _, _, m_block), tdOsdO_cur, t0QcQ, tQpQ, seqlenq_row_limit); + int const seqlenq_row_limit = seqlen_info.seqlen_q - m_block * kBlockM - get<0>(tQcQ(_0{}, _0{}, _0{})); + #pragma unroll + for (int m = 0; m < size<1>(tdOsdO); ++m) { + // If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked + if (EvenM || m < size<1>(tdOsdO) - 1 || get<0>(tQcQ(_0{}, m, _0{})) < kBlockM) { + bool const predicate_m = get<0>(t0QcQ(_0{}, m, _0{})) < seqlenq_row_limit; + #pragma unroll + for (int k = 0; k < size<2>(tdOsdO); ++k) { + cute::copy(gmem_tiled_copy_QKV.with(tdOpdO(k) && predicate_m), tdOgdO_cur(_, m, k), tdOsdO_cur(_, m, k)); + } + } + } + Tensor tLSEgdPsum_cur = tLSEgdPsum(_, _, m_block); + Tensor tLSEsdPsum_cur = tLSEsdPsum(_, _, smem_pipe_write); + #pragma unroll + for (int m = 0; m < size<1>(tLSEsdPsum); ++m) { + if (get<0>(tLSEcLSE(_0{}, m)) < kBlockM) { + cute::copy(gmem_tiled_copy_lse, tLSEgdPsum_cur(_, m), tLSEsdPsum_cur(_, m)); + } + } + }; + + int m_block = m_block_min; + + // Note, using the for_each() function here to ensure `stage` is of type Int. + for_each(make_int_sequence{}, [&] (auto stage) { + static constexpr bool Is_first_stage = CUTE_STATIC_V(stage) == 0; + static constexpr bool Is_last_stage = CUTE_STATIC_V(stage) == kStages - 1; + if constexpr (!Is_last_stage || kStages == 1) { + if (Is_first_stage || m_block + stage < m_block_max) { + load_Q_LSE(m_block + stage, stage); + } + } + // We want the fence outside the if statement to have a fixed number of cp.async commits. + // so that we can wait with the correct number of outstanding commits. 
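The fixed-commit-count rule in the comment above is what lets the later cp_async_wait<N>() calls refer to a specific earlier iteration's loads. A self-contained CUDA sketch of the same pattern using the generic __pipeline_* primitives (pipelined_sum, kTile and n_tiles are illustrative and not part of this kernel):

#include <cuda_pipeline.h>
#include <cstdio>

// One commit group per loop iteration, issued unconditionally, so that
// __pipeline_wait_prior(1) always means "everything except the newest group".
__global__ void pipelined_sum(const float* __restrict__ in, float* out, int n_tiles) {
    constexpr int kTile = 128, kStages = 2;
    __shared__ float buf[kStages][kTile];
    float acc = 0.0f;
    // Prologue: stage 0.
    __pipeline_memcpy_async(&buf[0][threadIdx.x], &in[threadIdx.x], sizeof(float));
    __pipeline_commit();
    for (int t = 0; t < n_tiles; ++t) {
        const int next = t + 1;
        if (next < n_tiles) {   // prefetch the next tile into the other buffer
            __pipeline_memcpy_async(&buf[next % kStages][threadIdx.x],
                                    &in[next * kTile + threadIdx.x], sizeof(float));
        }
        // Commit OUTSIDE the if: even an empty group keeps the per-iteration
        // count fixed, so the wait below still drains tile t's group on the
        // last iterations when nothing new was issued.
        __pipeline_commit();
        __pipeline_wait_prior(1);
        acc += buf[t % kStages][threadIdx.x];
    }
    atomicAdd(out, acc);
}

int main() {
    constexpr int kTile = 128, n_tiles = 4;
    float *in, *out;
    cudaMallocManaged(&in, kTile * n_tiles * sizeof(float));
    cudaMallocManaged(&out, sizeof(float));
    for (int i = 0; i < kTile * n_tiles; ++i) in[i] = 1.0f;
    *out = 0.0f;
    pipelined_sum<<<1, kTile>>>(in, out, n_tiles);
    cudaDeviceSynchronize();
    std::printf("sum = %.0f (expected %d)\n", *out, kTile * n_tiles);
    cudaFree(in); cudaFree(out);
    return 0;
}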
+ cute::cp_async_fence(); + if constexpr (stage < kStages_dO) { + if (Is_first_stage || m_block + stage < m_block_max) { + load_dO_dPsum(m_block + stage, stage); + } + cute::cp_async_fence(); + } + }); + + int smem_pipe_read = 0, smem_pipe_read_do = 0, smem_pipe_write = kStages - 1, smem_pipe_write_do = 0; + + auto load_Q_next = [&] { + // if (cute::thread0()) { printf("m_block = %d, m_block_max = %d, smem_pipe_write = %d\n", m_block, m_block_max, smem_pipe_write); } + if (m_block + (kStages > 1 ? kStages - 1 : 1) < m_block_max) { + load_Q_LSE(m_block + (kStages > 1 ? kStages - 1 : 1), kStages > 1 ? smem_pipe_write : 0); + } + cute::cp_async_fence(); + }; + + auto load_dO_next = [&] { + // int smem_pipe_write_do_cur = Q_dO_same_stages ? smem_pipe_write : smem_pipe_write_do; + if (m_block + kStages_dO < m_block_max) { + // load_dO_dPsum(m_block + kStages_dO, kStages_dO > 1 ? smem_pipe_write_do_cur : 0); + load_dO_dPsum(m_block + kStages_dO, kStages_dO > 1 ? smem_pipe_write_do : 0); + } + cute::cp_async_fence(); + }; + + clear(tdKrdK); + clear(tdVrdV); + + auto bwd_step = [&](int m_block, auto mask_fn) { + Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select(TileShape_MNK{})); + clear(tSrS); + flash::cp_async_wait<(kStages > 1) ? 1 : 0>(); + __syncthreads(); + Tensor tSrQ = mma_partition_fragment_AB(thr_mma_SdP, sQ(_, _, _0{})); + Tensor tSrK = mma_partition_fragment_AB(thr_mma_SdP, sK); + // if (cute::thread0()) { print(tiled_mma_SdP); print(tSrS); printf("\n"); print(tSrQ); printf("\n"); print(tSrK); printf("\n"); print(tSsQ); printf("\n"); print(tSsK); printf("\n"); } + flash::gemm_sm80( + tSrS, tSrQ, tSrK, tSsQ(_, _, _, kStages > 1 ? smem_pipe_read : 0), tSsK, + tiled_mma_SdP, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV, nullptr /*hook*/); + Tensor tLSErLSE = cute::conditional_return(make_fragment_like(tSsLSE(_, _0{})), make_tensor(Int{})); + if constexpr (!ShuffleLSE) { + cute::copy(tSsLSE(_, kStages > 1 ? smem_pipe_read : 0), tLSErLSE); + } else { + #pragma unroll + for (int i = 0; i < kStatsPerThread; ++i) { + // It's ok to read OOB, since we made sure sLSE is large enough and we won't use the OOB values + tLSErLSE(i) = tSsLSE((thread_idx % 32) / 4 + i * 8, kStages > 1 ? smem_pipe_read : 0); + } + } + if constexpr (Has_softcap) { flash::apply_softcap(tSrS, params.softcap_val); } + + // Reshape tSrS from (4, MMA_N, MMA_M) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout())); + // dtanh needs to happen before masking, otherwise we get 1 - (-inf)^2 = NaN in the dtanh + // if (cute::thread0()) { print_tensor(scores); } + auto dtanh = [&] { if constexpr (Has_softcap) return flash::calculate_dtanh(scores); else return nullptr; }(); + mask_fn(tSrS, m_block); + #pragma unroll + for (int mi = 0; mi < size<0>(scores); ++mi) { + float const lse_scaled = [&] { + if constexpr (!ShuffleLSE) return tLSErLSE(mi); + else return __shfl_sync(0xffffffff, tLSErLSE(mi / 8), (mi % 8) * 4 + (thread_idx % 4)); + }(); + #pragma unroll + for (int ni = 0; ni < size<1>(scores); ++ni) { + scores(mi, ni) = exp2f(scores(mi, ni) * params.softmax_scale_log2 - lse_scaled); + } + } + + Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select(TileShape_MNK{})); + clear(tdPrdP); + int smem_pipe_read_do_cur = Q_dO_same_stages ? smem_pipe_read : smem_pipe_read_do; + flash::cp_async_wait<(kStages_dO > 1) ? 
1 : 0>(); + __syncthreads(); + auto hook = cute::conditional_return<(kStages > 1)>(load_Q_next, nullptr); + Tensor tdPrdO = mma_partition_fragment_AB(thr_mma_SdP, sdO(_, _, _0{})); + Tensor tdPrV_cur = cute::conditional_return(tdPrV, mma_partition_fragment_AB(thr_mma_SdP, sV)); + flash::gemm_sm80( + tdPrdP, tdPrdO, tdPrV_cur, tdPsdO(_, _, _, kStages_dO > 1 ? smem_pipe_read_do_cur : 0), tdPsV, + tiled_mma_SdP, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV, hook); + Tensor tLSErdPsum = cute::conditional_return(make_fragment_like(tSsdPsum(_, _0{})), make_tensor(Int{})); + if constexpr (!ShuffledPsum) { + cute::copy(tSsdPsum(_, kStages_dO > 1 ? smem_pipe_read_do_cur : 0), tLSErdPsum); + } else { + #pragma unroll + for (int i = 0; i < kStatsPerThread; ++i) { + tLSErdPsum(i) = tSsdPsum((thread_idx % 32) / 4 + i * 8, kStages_dO > 1 ? smem_pipe_read_do_cur : 0); + } + } + + // Reshape tdPrdP from (4, MMA_N, MMA_M) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor dS = make_tensor(tdPrdP.data(), scores.layout()); + #pragma unroll + for (int mi = 0; mi < size<0>(dS); ++mi) { + float const dP_sum_cur = [&] { + if constexpr (!ShuffledPsum) return tLSErdPsum(mi); + else return __shfl_sync(0xffffffff, tLSErdPsum(mi / 8), (mi % 8) * 4 + (thread_idx % 4)); + }(); + #pragma unroll + for (int ni = 0; ni < size<1>(dS); ++ni) { + dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum_cur); + if constexpr (Has_softcap) { dS(mi, ni) *= dtanh(mi, ni); } + } + } + // if (cute::thread0()) { print_tensor(dS); } + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = make_tensor_like(tSrS); + flash::convert_type_out(tSrS, rP); + if constexpr (!Mma_dKV_is_RS) { + Tensor tPaP = r2s_thr_copy_PdS.retile_S(rP); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(r2s_tiled_copy_PdS, tPaP, tPsP); + } + Tensor rdS = make_tensor_like(tdPrdP); + flash::convert_type_out(tdPrdP, rdS); + if constexpr (!Mma_dKV_is_RS) { __syncthreads(); } // Make sure P is written + // For hdim 64, It's faster to write to smem_dS first before the dV gemm + Tensor tdSadS = r2s_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(r2s_tiled_copy_PdS, tdSadS, tdSsdS); + + Tensor tdVrdO = mma_partition_fragment_AB(thr_mma_dKV, sdOt(_, _, _0{})); + Tensor tdVsdO_cur = tdVsdOt(_, _, _, kStages_dO > 1 ? 
smem_pipe_read_do_cur : 0); + if constexpr (Mma_dKV_is_RS) { + Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs(tSrS.layout())); + flash::gemm_rs_sm80(tdVrdV, tdVrP, tdVrdO, tdVsdO_cur, tiled_mma_dKV, smem_tiled_copy_QdOt, smem_thr_copy_QdOt); + } else { + Tensor tdVrP = mma_partition_fragment_AB(thr_mma_dKV, sPt); + flash::gemm_sm80( + tdVrdV, tdVrP, tdVrdO, tdVsPt, tdVsdO_cur, + tiled_mma_dKV, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt, nullptr); + } + // if (cute::thread0()) { print_tensor(tdVrdV); } + __syncthreads(); // make sure sdS is written + auto do_mma_dQ = [&] (auto hook) { + Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{})); + clear(tdQrdQ); + Tensor tdQrdS = mma_partition_fragment_AB(thr_mma_dQ, sdS); + Tensor tdQrK = mma_partition_fragment_AB(thr_mma_dQ, sKt); + flash::gemm_sm80( + tdQrdQ, tdQrdS, tdQrK, tdQsdS, tdQsKt, tiled_mma_dQ, + // smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt, load_dO_next); + smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt, hook); + // if (cute::thread0()) { print_tensor(tdQrdQ); } + // We can reuse r2s_thr_copy_dQaccum for this partitioning + Tensor tdQrdQ_atomic = r2s_thr_copy_dQaccum.retile_S(tdQrdQ); + Tensor tdQgdQaccum_atomic = tdQgdQaccum(_, _, m_block); + static_assert(CUTE_STATIC_V(size(tdQrdQ_atomic)) == CUTE_STATIC_V(size(tdQgdQaccum_atomic))); + #pragma unroll + for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } + }; + // If kStages == 1, we want to do Mma_dK first so we can start loading Q for the next iteration + if constexpr (kStages > 1) { do_mma_dQ(load_dO_next); } + Tensor tdKrQ = mma_partition_fragment_AB(thr_mma_dKV, sQt(_, _, _0{})); + Tensor tdKsQ_cur = tdKsQt(_, _, _, kStages > 1 ? smem_pipe_read : 0); + if constexpr (Mma_dKV_is_RS) { + Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); + flash::gemm_rs_sm80(tdKrdK, tdKrdS, tdKrQ, tdKsQ_cur, tiled_mma_dKV, smem_tiled_copy_QdOt, smem_thr_copy_QdOt); + } else { + Tensor tdKrdS = mma_partition_fragment_AB(thr_mma_dKV, sdSt); + flash::gemm_sm80( + tdKrdK, tdKrdS, tdKrQ, tdKsdSt, tdKsQ_cur, + tiled_mma_dKV, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt, cute::conditional_return<(kStages > 1)>(nullptr, load_dO_next)); + } + if constexpr (kStages == 1) { + __syncthreads(); + do_mma_dQ(load_Q_next); + } + // if (cute::thread0()) { print_tensor(tdKrdK); } + + smem_pipe_read = smem_pipe_read < kStages - 1 ? smem_pipe_read + 1 : 0; + smem_pipe_read_do = smem_pipe_read_do < kStages_dO - 1 ? smem_pipe_read_do + 1 : 0; + smem_pipe_write = smem_pipe_write < kStages - 1 ? smem_pipe_write + 1 : 0; + smem_pipe_write_do = smem_pipe_write_do < kStages_dO - 1 ? smem_pipe_write_do + 1 : 0; + + }; + + // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64 + // this helps quite a bit to not have to do causal masking for most of the iterations. 
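The per-row update inside bwd_step above is the standard softmax backward: P is recomputed from the saved LSE, and dS = P * (dP - dPsum), where dPsum = rowsum(dO * O) is precomputed so no second pass over P is needed. A small numeric check of that identity in plain C++ (the 4-key row, V and dO values are made up for illustration):

#include <cassert>
#include <cmath>
#include <cstdio>

// One query row, 4 keys, head dim 3.
int main() {
    const int N = 4, D = 3;
    const double s[4]    = {0.3, -1.2, 2.0, 0.5};            // scaled attention scores
    const double V[4][3] = {{1, 0, 2}, {0, 1, 1}, {3, -1, 0}, {0.5, 0.5, 0.5}};
    const double dO[3]   = {0.1, -0.2, 0.3};                 // gradient w.r.t. the output row

    // Forward softmax for this row (the kernel instead recomputes P from LSE).
    double denom = 0, P[4], O[3] = {0, 0, 0};
    for (int j = 0; j < N; ++j) denom += std::exp(s[j]);
    for (int j = 0; j < N; ++j) P[j] = std::exp(s[j]) / denom;
    for (int j = 0; j < N; ++j)
        for (int d = 0; d < D; ++d) O[d] += P[j] * V[j][d];

    // dP = dO V^T, and the two ways of getting the row sum agree:
    // sum_j P_j dP_j == sum_d dO_d O_d, which is why dPsum can be computed from
    // dO and O in a preprocessing step, before the mainloop ever sees P.
    double dP[4], sum_PdP = 0, dPsum = 0, dS[4];
    for (int j = 0; j < N; ++j) {
        dP[j] = 0;
        for (int d = 0; d < D; ++d) dP[j] += dO[d] * V[j][d];
        sum_PdP += P[j] * dP[j];
    }
    for (int d = 0; d < D; ++d) dPsum += dO[d] * O[d];
    assert(std::fabs(sum_PdP - dPsum) < 1e-12);

    // Softmax backward, the same per-element update bwd_step applies:
    for (int j = 0; j < N; ++j) dS[j] = P[j] * (dP[j] - dPsum);
    std::printf("dS = [%g, %g, %g, %g]\n", dS[0], dS[1], dS[2], dS[3]);
    return 0;
}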
+ if constexpr ((Is_causal || Is_local) && SeparateMaskingIterations) { + auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; + int const m_block_masking_max = ((n_block + 1) * kBlockN - 1 + seqlen_q - seqlen_k - params.window_size_right) / kBlockM + 1; + CUTLASS_PRAGMA_NO_UNROLL + for (; m_block < std::min(m_block_max, m_block_masking_max); ++m_block) { + bwd_step(m_block, mask_fn); + } + } + + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + int const m_block_max_before_local_mask = !Is_local || !SeparateMaskingIterations + ? m_block_max + : std::min(m_block_max, (n_block * kBlockN + seqlen_q - seqlen_k + params.window_size_left) / kBlockM); + + auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; + CUTLASS_PRAGMA_NO_UNROLL + for (; m_block < m_block_max_before_local_mask; ++m_block) { + bwd_step(m_block, mask_fn); + } + + if constexpr (Is_local && SeparateMaskingIterations) { + auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; + CUTLASS_PRAGMA_NO_UNROLL + for (; m_block < m_block_max; ++m_block) { + bwd_step(m_block, mask_fn); + } + } + + // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tdVrdV); } + #pragma unroll + for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.softmax_scale; } + + return true; + } + +}; + +} // namespace flash diff --git a/flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp b/flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ec34e20eca1f25bb2a5550976a79da474be361b7 --- /dev/null +++ b/flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -0,0 +1,1038 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include "cutlass/pipeline/pipeline.hpp" + +#include "cute/tensor.hpp" + +#include "cutlass/gemm/collective/builders/sm90_common.inl" + +#include "named_barrier.hpp" +#include "seqlen.h" +#include "block.h" +#include "mask.h" +#include "softmax.h" +#include "utils.h" +#include "copy_sm90_bulk_reduce.hpp" + +namespace flash { + +using namespace cute; + +template +struct CollectiveMainloopBwdSm90 { + + static constexpr int kStages = Stages; + static constexpr int kStages_dO = Stages_dO; + static constexpr int kStages_dS = Stages_dS; + static_assert(kStages >= kStages_dO); + static_assert(Stages_dS == 1 || Stages_dS == kStages); + static_assert(!Mma_dP_is_RS || SdP_swapAB_); // If Mma_dP_is_RS, we need SdP_SwapAB + using ClusterShape = ClusterShape_; + using TileShape_MNK = TileShape_MNK_; + using Element = Element_; + using ElementAccum = ElementAccum_; + using ArchTag = ArchTag_; + static constexpr bool Is_causal = Is_causal_; + static constexpr bool Is_local = Is_local_; + static constexpr bool Has_softcap = Has_softcap_; + static constexpr bool Varlen = Varlen_; + + static constexpr bool SdP_swapAB = SdP_swapAB_; + static constexpr bool dKV_swapAB = dKV_swapAB_; + static constexpr bool dQ_swapAB = dQ_swapAB_; + + static constexpr bool Q_dO_same_stages = kStages == kStages_dO; + + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + + using SeqlenInfo_t = flash::SeqlenInfoQK; + using BlockMN_t = flash::BlockMN; + + static_assert(ArchTag::kMinComputeCapability >= 90); + static_assert(get<0>(ClusterShape{}) == 1 && get<2>(ClusterShape{}) == 1); + + static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; + static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarp * 2; + + static_assert(NumMmaWarpGroups % AtomLayoutMSdP == 0); + static_assert(NumMmaWarpGroups % AtomLayoutNdKV == 0); + static_assert(NumMmaWarpGroups % AtomLayoutMdQ == 0); + static constexpr bool Mma_dKV_is_RS = AtomLayoutMSdP == 1 && AtomLayoutNdKV == NumMmaWarpGroups && SdP_swapAB && !dKV_swapAB; + static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == NumMmaWarpGroups && AtomLayoutMdQ == NumMmaWarpGroups && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS + + static constexpr GMMA::Major PdS_Major = GMMA::Major::K; + // static constexpr GMMA::Major PdS_Major = GMMA::Major::MN; + static constexpr GMMA::Major PdSt_Major = PdS_Major == GMMA::Major::K ? 
GMMA::Major::MN : GMMA::Major::K; + + using TileShapeAtomSdP = std::conditional_t< + !SdP_swapAB, + Shape, Int, Int>, + Shape, Int, Int> + >; + using AtomLayoutSdP = std::conditional_t< + !SdP_swapAB, + Layout, Int, _1>>, + Layout, Int, _1>> + >; + using TiledMmaSdP = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutSdP{})); + + using TiledMmadPRS = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(), + AtomLayoutSdP{})); + + using TileShapeAtomdKV = std::conditional_t< + !dKV_swapAB, + Shape, Int, Int>, + Shape, Int, Int> + >; + using AtomLayoutdKV = std::conditional_t< + !dKV_swapAB, + Layout, Int, _1>>, + Layout, Int, _1>> + >; + using TiledMmadKV = decltype(cute::make_tiled_mma( + std::conditional_t< + Mma_dKV_is_RS, + decltype(cute::GMMA::rs_op_selector()), + decltype(cute::GMMA::ss_op_selector()) + >{}, + AtomLayoutdKV{})); + + using TileShapeAtomdQ = std::conditional_t< + !dQ_swapAB, + Shape, Int, Int>, + Shape, Int, Int> + >; + using AtomLayoutdQ = std::conditional_t< + !dQ_swapAB, + Layout, Int, _1>>, + Layout, Int, _1>> + >; + using TiledMmadQ = decltype(cute::make_tiled_mma( + std::conditional_t< + Mma_dQ_is_RS, + decltype(cute::GMMA::rs_op_selector()), + decltype(cute::GMMA::ss_op_selector()) + >{}, + AtomLayoutdQ{})); + + // We need to accommodate both Q and Q^T (and dO and dO^T) in shared memory. + // Q & dO are used in the SdP Mma and Q^T and dO^T are used in the dKV Mma. + // Since this is GMMA::Major::K, the M dimension (kBlockM) doesn't matter for the layout, only the K dimension + // changes the layout. + using SmemLayoutAtomQdO = decltype(cutlass::gemm::collective::detail::ss_smem_selector, Int>()); // for dKV_Mma + using SmemLayoutQ = + decltype(tile_to_shape(SmemLayoutAtomQdO{}, + make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + using SmemLayoutdO = + decltype(tile_to_shape(SmemLayoutAtomQdO{}, + make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + + using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector, Int>()); + using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{}))); + + using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{}))); + + using SmemLayoutAtomPdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector, + Int>()); + using SmemLayoutPdS = decltype(tile_to_shape( + SmemLayoutAtomPdS{}, + make_shape(Int{}, Int{}, Int{}), + std::conditional_t, cute::Step<_2, _1, _3>>{})); + + // Need stride to be multiple of 32, otherwise we get error (misaligned address) when doing TMA if e.g. kBlockM=80 + // We set stride to be multiple of 64 so that if ShuffleLSE, even if threads read from sLSE but out of bounds, + // it's still a valid smem address. + using SmemLayoutLSE = cute::Layout, Int>, cute::Stride<_1, Int>>; + using SmemLayoutLSEMma = std::conditional_t< + SdP_swapAB, + cute::Layout, Int, Int>, cute::Stride<_0, _1, Int>>, + cute::Layout, Int, Int>, cute::Stride<_1, _0, Int>> + >; + + // Note this is the transpose in terms of the view, not in terms of memory. 
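To make the preceding note concrete, here is a small host-side sketch of the same composition idiom (sizes and names are made up for illustration; only cute::composition, make_layout, and Int<> are assumed): composing a layout with a second layout whose strides walk the original domain column-first yields a transposed view of the same memory, with no data movement, which is how SmemLayoutQt, SmemLayoutdOt, and SmemLayoutKt are built just below.

#include <cstdio>
#include <cute/tensor.hpp>

int main() {
    using namespace cute;
    float data[6] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f};
    // A(m, k): a 2 x 3 view that is K-major in memory, i.e. stride (3, 1).
    auto A  = make_layout(make_shape(Int<2>{}, Int<3>{}), make_stride(Int<3>{}, Int<1>{}));
    // At(k, m): compose A with a (3, 2)-shaped layout of stride (2, 1), the 2D analogue
    // of the (kBlockM, _1, kBlockM * kHeadDim) strides used for SmemLayoutQt below.
    auto At = composition(A, make_layout(make_shape(Int<3>{}, Int<2>{}),
                                         make_stride(Int<2>{}, Int<1>{})));
    for (int k = 0; k < 3; ++k) {
        for (int m = 0; m < 2; ++m) {
            // Both views hit the same element: At is A with the coordinates swapped.
            printf("A(%d,%d) = %g   At(%d,%d) = %g\n", m, k, data[A(m, k)], k, m, data[At(k, m)]);
        }
    }
    return 0;
}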
+ using SmemLayoutQt = + decltype(cute::composition(SmemLayoutQ{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), + make_stride(Int{}, _1{}, Int{})))); + using SmemLayoutdOt = + decltype(cute::composition(SmemLayoutdO{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), + make_stride(Int{}, _1{}, Int{})))); + using SmemLayoutKt = + decltype(cute::composition(SmemLayoutK{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), + make_stride(Int{}, _1{})))); + using SmemLayoutPdSt = + decltype(cute::composition(SmemLayoutPdS{}, + make_layout(make_shape(Int{}, Int{}, Int{}), + make_stride(Int{}, _1{}, Int{})))); + + // Thread layout, 256 or 384 threads per row + // We split into NumMmaWarpGroups so that we can do Bulk reduce add for each WG separately. + using R2SLayoutAtomdQaccum = Layout, Int>>; + using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2SLayoutAtomdQaccum{}, + Layout>{})); // Val layout, 4 vals per store + using SmemLayoutdQaccum = Layout, Int>>; + + static constexpr int kNumPdSStore = kBlockM * kBlockN / NumMmaThreads; + // If !SdP_swapAB, the accum registers hold P / dS, otherwise they hold Pt / dSt. + // If PdS_major is MN, then we need to "transpose" the write. + using SmemCopyAtomPdS = Copy_Atom< + std::conditional_t<(!SdP_swapAB) ^ (PdS_Major == GMMA::Major::MN), + std::conditional_t, + std::conditional_t + >, + Element + >; + + using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape{}))); + using GmemTiledCopyKV = cute::SM90_TMA_LOAD; + + using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) + using StrideQKV = cute::Stride; + using ShapeLSE = cute::Shape; // (seqlen, head, batch) + using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch) + using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch) + using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; + + using TMA_QdO = decltype(make_tma_copy_A_sm90( + GmemTiledCopyQdO{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), + take<0, 2>(SmemLayoutQ{}), + TileShape_MNK{}, + ClusterShape{})); // mcast along N mode for this M load, if any + + using TMA_K = decltype(make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), + SmemLayoutK{}, + TileShape_MNK{}, + ClusterShape{})); // no mcast for KV + + using TMA_V = decltype(make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), + SmemLayoutV{}, + TileShape_MNK{}, + ClusterShape{})); // no mcast for KV + + using MainloopPipeline = typename cutlass::PipelineTmaAsync; + using PipelineState = typename MainloopPipeline::PipelineState; + using MainloopPipeline_dO = typename cutlass::PipelineTmaAsync; + using PipelineState_dO = typename MainloopPipeline_dO::PipelineState; + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(take<0, 2>(SmemLayoutQ{})) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesK = static_cast(size(SmemLayoutK{}) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesV = static_cast(size(SmemLayoutV{}) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesLSE = static_cast(size(select<0>(SmemLayoutLSE{})) * 
cutlass::sizeof_bits_v / 8); + + // These are tuned for speed. They don't affect correctness. + // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64 + // this helps quite a bit to not have to do causal masking for most of the iterations. + // For hdim 192, separating masking iterations results in register spills. + static constexpr bool SeparateMaskingIterations = kHeadDim <= 64; + // Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share them and then + // shuffle to get the value whenever we need? This can reduce register pressure when SdP_swapAB, where each + // thread needs to keep statistics for (kBlockM / 4) rows. If !SdP_swapAB, each thread only needs to keep + // statistic for 2 rows. + static constexpr bool ShuffleLSE = SdP_swapAB && kHeadDim <= 64; + static constexpr bool ShuffledPsum = SdP_swapAB && kHeadDim <= 64; + static constexpr bool dQacc_use_TMA = kHeadDim < 256; + // For hdim256, we want to slice the dQ MMA (64 x 256 on 2 WGs) into two (64 x 128 on 2 WGs) so that we can + // do atomic add on one half before doing the other half of the MMA, to reduce register pressure. + static constexpr bool Slice_dQKV_Mma = kHeadDim == 256 && !dQacc_use_TMA && dQ_swapAB && AtomLayoutMdQ == 1 && NumMmaWarpGroups == 2; + static_assert(!(Deterministic && Slice_dQKV_Mma), "Deterministic mode not supported with Slice_dQKV_Mma"); + + static constexpr size_t SmemAlignmentP = cutlass::detail::alignment_for_swizzle(SmemLayoutPdS{}); + static constexpr size_t SmemAlignmentdS = cutlass::detail::alignment_for_swizzle(SmemLayoutPdS{}); + // Without this SmemAlignment, with hdim 256 we get "misaligned address" error in TMA + static constexpr size_t SmemAlignmentQKVdO = kHeadDim % 256 == 0 ? 256 : 128; + static constexpr size_t SmemAlignmentV = !Mma_dP_is_RS ? SmemAlignmentQKVdO : cutlass::detail::alignment_for_swizzle(SmemLayoutV{}); + static_assert(SmemAlignmentP >= 128 && SmemAlignmentdS >= 128, "Require at least 128B alignment"); + + // TODO: do we have to worry that smem_dk and smem_dv in the epilogue don't line up w smem_k and smem_v due to alignment? 
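As a rough standalone illustration of the ShuffleLSE / ShuffledPsum idea described above (the real index mapping lives further down in mma(); the kernel name, buffers, and launch shape here are made up): each of the 8 threads that share a set of rows keeps only one slice of the per-row statistics in registers, and the remaining values are fetched on demand with __shfl_sync.

__global__ void shuffle_stats_demo(float const* __restrict__ lse, float* __restrict__ out) {
    int const lane = threadIdx.x % 32;
    // Each group of 4 consecutive lanes (a "quad") owns one of the 8 row statistics.
    float const my_stat = lse[lane / 4];
    // Any lane can reconstruct statistic r by shuffling from a lane in the owning quad;
    // r * 4 + (lane % 4) spreads the reads over that quad's four lanes.
    for (int r = 0; r < 8; ++r) {
        float const stat_r = __shfl_sync(0xffffffffu, my_stat, r * 4 + (lane % 4));
        if (lane == 0) { out[r] = stat_r; }
    }
}

Launched as shuffle_stats_demo<<<1, 32>>>(lse, out) with lse holding 8 floats, every lane can recover all 8 statistics while keeping only one of them resident in a register, which is the register saving the comment above is after.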
+ using SmemdQacc_t = std::conditional_t, cute::array_aligned>>; + using SmemP_t = std::conditional_t, cute::array_aligned, SmemAlignmentP>>; + struct TensorStorage : cute::aligned_struct { + cute::array_aligned, SmemAlignmentQKVdO> smem_k; + cute::array_aligned, SmemAlignmentV> smem_v; + SmemdQacc_t smem_dqacc; + cute::array_aligned, SmemAlignmentQKVdO> smem_q; + cute::array_aligned, SmemAlignmentQKVdO> smem_do; + cute::array_aligned, 128> smem_lse; + cute::array_aligned, 128> smem_dpsum; + SmemP_t smem_p; + cute::array_aligned, SmemAlignmentdS> smem_ds; + }; + + // Host side kernel arguments + struct Arguments { + Element const* const ptr_Q; + ShapeQKV const shape_Q; + StrideQKV const stride_Q; + Element const* const ptr_K; + ShapeQKV const shape_K; + StrideQKV const stride_K; + Element const* const ptr_V; + ShapeQKV const shape_V; + StrideQKV const stride_V; + Element const* const ptr_dO; + ShapeQKV const shape_dO; + StrideQKV const stride_dO; + ElementAccum* const ptr_dQaccum; + ShapedQaccum const shape_dQaccum; + StridedQaccum const stride_dQaccum; + float const* const ptr_LSE_log2; + ShapeLSE const shape_LSE; + StrideLSE const stride_LSE_log2; + float const* const ptr_dPsum; + StrideLSE const stride_dPsum; + float const softmax_scale; + int const window_size_left, window_size_right, attention_chunk; + float const softcap_val; + int const num_batch; + int* const dq_semaphore; + int const* const cu_seqlens_q = nullptr; + int const* const cu_seqlens_k = nullptr; + int const* const seqused_q = nullptr; + int const* const seqused_k = nullptr; + }; + + // Device side kernel params + struct Params { + ShapeQKV const shape_Q; + ShapeQKV const shape_K; + ShapeQKV const shape_V; + ShapeQKV const shape_dO; + ElementAccum* const ptr_dQaccum; + ShapedQaccum const shape_dQaccum; + StridedQaccum stride_dQaccum; + cutlass::FastDivmod qhead_per_khead_divmod; + TMA_QdO tma_load_Q, tma_load_dO; + TMA_K tma_load_K; + TMA_V tma_load_V; + float const* const ptr_LSE_log2; + ShapeLSE const shape_LSE; + StrideLSE const stride_LSE_log2; + float const* const ptr_dPsum; + StrideLSE const stride_dPsum; + float const softmax_scale, softmax_scale_log2; + int const window_size_left, window_size_right; + cutlass::FastDivmod attention_chunk_divmod; + float const softcap_val; + int const num_batch; + int* const dq_semaphore; + int const* const cu_seqlens_q = nullptr; + int const* const cu_seqlens_k = nullptr; + int const* const seqused_q = nullptr; + int const* const seqused_k = nullptr; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q); + TMA_QdO tma_load_Q = make_tma_copy_A_sm90( + GmemTiledCopyQdO{}, + mQ, + SmemLayoutQ{}(_, _, _0{}), + TileShape_MNK{}, + ClusterShape{}); // mcast along N mode for this M load, if any + Tensor mdO = make_tensor(make_gmem_ptr(args.ptr_dO), args.shape_dO, args.stride_dO); + TMA_QdO tma_load_dO = make_tma_copy_A_sm90( + GmemTiledCopyQdO{}, + mdO, + SmemLayoutdO{}(_, _, _0{}), + TileShape_MNK{}, + ClusterShape{}); // mcast along N mode for this M load, if any + Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K); + TMA_K tma_load_K = make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + mK, + SmemLayoutK{}, + TileShape_MNK{}, + ClusterShape{}); // no mcast for KV + Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_V, args.stride_V); + TMA_V tma_load_V = make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + mV, + SmemLayoutV{}, + TileShape_MNK{}, + 
ClusterShape{}); // no mcast for KV + if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); } + // Avoid dividing by zero + cutlass::FastDivmod attention_chunk_divmod(args.attention_chunk >= 1 ? args.attention_chunk : 1); + attention_chunk_divmod.divisor = args.attention_chunk; + // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. + // Right after this, we multiply by log2(e) before applying exp2. + // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val + // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) + // (assigning it to params.softmax_scale_log2). + // In the backward, we need to multiply by + // (1 - tanh^2) * softmax_scale / softcap_val * softcap_val = (1 - tanh^2) * softmax_scale. + // Instead we multiply by (1 - tanh^2) and multiply dK and dV by params.softmax_scale + // (the original softmax_scale) at the end. + return {args.shape_Q, args.shape_K, + args.shape_V, args.shape_dO, + args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum, + cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), + tma_load_Q, tma_load_dO, tma_load_K, tma_load_V, + args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, + args.softmax_scale, + !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), + args.window_size_left, args.window_size_right, attention_chunk_divmod, + !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, + args.num_batch, args.dq_semaphore, + args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_dO.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE void + load(Params const& params, + MainloopPipeline pipeline_q, + MainloopPipeline_dO pipeline_do, + PipelineState& smem_pipe_write, + PipelineState_dO& smem_pipe_write_do, + SharedStorage &shared_storage, + SchedulerPrefetch const& scheduler_prefetch, + cute::tuple block_coord + ) { + + auto [n_block, bidh, bidb] = block_coord; + SeqlenInfo_t seqlen_info{ + bidb, get<0>(params.shape_Q), size<0>(params.shape_K), + params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k + }; + auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( + seqlen_info, n_block, bidb, + params.window_size_left, params.window_size_right, 0 /*sink_token_length*/); + // It's possible to have m_block_max <= m_block_min. Loading Q, K can cause illegal memory access. 
+ if constexpr (Is_causal || Is_local || Varlen) { + if (m_block_max <= m_block_min) { + scheduler_prefetch(); + return; + } + } + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); + Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); + Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSE{}); + Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSE{}); + + int bidh_kv = params.qhead_per_khead_divmod.divide(bidh); + + // Prepare the TMA loads + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + bool const is_varlen_q = Varlen && params.cu_seqlens_q; + bool const is_varlen_k = Varlen && params.cu_seqlens_k; + Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_dO)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor mK = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); + Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_V)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); + Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, bidh, !is_varlen_q ? bidb : 0); + Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, bidh, !is_varlen_q ? 
bidb : 0); + + Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) + Tensor gdO = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) + Tensor gK = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) + Tensor gV = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) + Tensor gLSE = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mLSE), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _) + Tensor gdPsum = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mdPsum), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _) + + Tensor sK_x = make_tensor(sK.data(), make_layout(sK.layout(), Layout<_1>{})); + Tensor gK_x = make_tensor(gK.data(), make_layout(gK.layout(), Layout<_1>{})); + Tensor sV_x = make_tensor(sV.data(), make_layout(sV.layout(), Layout<_1>{})); + Tensor gV_x = make_tensor(gV.data(), make_layout(gV.layout(), Layout<_1>{})); + // auto [tQgQ, tQsQ] = tma_partition(params.tma_load_Q, block_rank_in_cluster, Layout{}, + // group_modes<0, 2>(sQ), group_modes<0, 2>(gQ)); // (TMA, k), (TMA, PIPE) + // auto [tdOgdO, tdOsdO] = tma_partition(params.tma_load_dO, block_rank_in_cluster, Layout{}, + // group_modes<0, 2>(sdO), group_modes<0, 2>(gdO)); // (TMA, k), (TMA, PIPE) + auto block_tma_Q = params.tma_load_Q.get_slice(cluster_local_block_id.y); + auto block_tma_dO = params.tma_load_dO.get_slice(cluster_local_block_id.y); + Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); + Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); + Tensor tdOgdO = group_modes<0, 3>(block_tma_dO.partition_S(gdO)); + Tensor tdOsdO = group_modes<0, 3>(block_tma_dO.partition_D(sdO)); + auto [tKgK, tKsK] = tma_partition(params.tma_load_K, _0{}, Layout<_1>{}, + group_modes<0, 2>(sK_x), group_modes<0, 2>(gK_x)); // (TMA), (TMA) + auto [tVgV, tVsV] = tma_partition(params.tma_load_V, _0{}, Layout<_1>{}, + group_modes<0, 2>(sV_x), group_modes<0, 2>(gV_x)); // (TMA), (TMA) + auto bulk_copy = Copy_Traits{}; + + uint16_t mcast_mask_qdo = 0; + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_qdo |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, _0{})); + } + } + + int m_block = m_block_min; + + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + pipeline_q.producer_acquire(smem_pipe_write); + copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), + tQgQ(_, m_block), tQsQ(_, smem_pipe_write.index())); + copy(bulk_copy.with(*pipeline_q.producer_get_barrier(smem_pipe_write)), + gLSE(_, m_block), sLSE(_, smem_pipe_write.index())); + } + + // // Wait for the MMA warpgroups to say that smem_k and smem_v are ready + // cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); + + if (lane_predicate) { + // Copy K tile and V tile from GMEM to SMEM. 
+ shared_storage.pipelines.barrier_KV.arrive_and_expect_tx(TmaTransactionBytesK + TmaTransactionBytesV); + copy(params.tma_load_K.with(reinterpret_cast(shared_storage.pipelines.barrier_KV), 0 /*mcast_mask*/), tKgK, tKsK); + copy(params.tma_load_V.with(reinterpret_cast(shared_storage.pipelines.barrier_KV), 0 /*mcast_mask*/), tVgV, tVsV); + + #pragma unroll (kHeadDim < 256 ? 2 : 1) + for (; m_block < m_block_max - 1; ++m_block) { + // If Q and dO have the same number of stages, we can use the same pipeline state variable + // to reduce registers + PipelineState_dO smem_pipe_write_do_cur = cute::conditional_return(smem_pipe_write, smem_pipe_write_do); + pipeline_do.producer_acquire(smem_pipe_write_do_cur); + copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), + tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write_do_cur.index())); + copy(bulk_copy.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur)), + gdPsum(_, m_block), sdPsum(_, smem_pipe_write_do_cur.index())); + if constexpr (!Q_dO_same_stages) { ++smem_pipe_write_do; } + ++smem_pipe_write; + pipeline_q.producer_acquire(smem_pipe_write); + copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), + tQgQ(_, m_block + 1), tQsQ(_, smem_pipe_write.index())); + copy(bulk_copy.with(*pipeline_q.producer_get_barrier(smem_pipe_write)), + gLSE(_, m_block + 1), sLSE(_, smem_pipe_write.index())); + } + } + scheduler_prefetch(); + if (lane_predicate) { + PipelineState_dO smem_pipe_write_do_cur = cute::conditional_return(smem_pipe_write, smem_pipe_write_do); + pipeline_do.producer_acquire(smem_pipe_write_do_cur); + copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), + tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write_do_cur.index())); + copy(bulk_copy.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur)), + gdPsum(_, m_block), sdPsum(_, smem_pipe_write_do_cur.index())); + if constexpr (!Q_dO_same_stages) { ++smem_pipe_write_do; } + ++smem_pipe_write; + } + if constexpr (Q_dO_same_stages) { smem_pipe_write_do = smem_pipe_write; } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline_q, MainloopPipeline_dO pipeline_do, + PipelineState& smem_pipe_write) { + static_assert(Q_dO_same_stages, "Q and dO must have the same number of stages"); + // Need to copy since pipeline_q.producer_tail(smem_pipe_write) will increment smem_pipe_write + PipelineState smem_pipe_write_do = smem_pipe_write; + // Issue the epilogue waits + if (cute::elect_one_sync()) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was still inverted from make_producer_start_state + */ + pipeline_q.producer_tail(smem_pipe_write); + pipeline_do.producer_tail(smem_pipe_write_do); + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline_q, MainloopPipeline_dO pipeline_do, + PipelineState& smem_pipe_write, PipelineState_dO& smem_pipe_write_do) { + // Issue the epilogue waits + if (cute::elect_one_sync()) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all Consumer 
UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was still inverted from make_producer_start_state + */ + pipeline_q.producer_tail(smem_pipe_write); + pipeline_do.producer_tail(smem_pipe_write_do); + } + } + + template + CUTLASS_DEVICE void + store_dq(Params const& params, + SharedStorage &shared_storage, + cute::tuple block_coord + ) { + if constexpr (!dQacc_use_TMA) { return; } + + auto [n_block, bidh, bidb] = block_coord; + SeqlenInfo_t seqlen_info{ + bidb, get<0>(params.shape_Q), size<0>(params.shape_K), + params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k + }; + auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( + seqlen_info, n_block, bidb, params.window_size_left, + params.window_size_right, 0 /*sink_token_length*/); + // It's possible to have m_block_max <= m_block_min. Exit early + if constexpr (Is_causal || Is_local || Varlen) { + if (m_block_max <= m_block_min) { return; } + } + + Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dqacc.data()), SmemLayoutdQaccum{}); + static constexpr int dQ_TMA_num_bytes = CUTE_STATIC_V(size<0>(sdQ)) * sizeof(ElementAccum); + + bool const is_varlen = Varlen && params.cu_seqlens_q; + Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), + params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0); + Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(_)); // (M * K, _) + Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int{}); // (M * K / WG, WG, _) + + int const num_batch = params.num_batch; + int const num_head = get<2>(params.shape_Q); + int *lock_ptr = !Deterministic ? nullptr : params.dq_semaphore + bidb * num_head + bidh; + using Barrier = cutlass::GenericBarrier; + bool const lane_predicate = cute::elect_one_sync(); + int m_block = m_block_min; + #pragma unroll 2 + for (; m_block < m_block_max; ++m_block) { + if constexpr (Deterministic) { + Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block); + } + #pragma unroll + for (int warpgroup_idx = 0; warpgroup_idx < NumMmaWarpGroups; ++warpgroup_idx) { + cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQFullWG1) + warpgroup_idx /*id*/); // sdQ full, to be written to gmem + if (lane_predicate) { + SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdQ(_, warpgroup_idx).data()), raw_pointer_cast(gdQaccum(_, warpgroup_idx, m_block).data()), dQ_TMA_num_bytes, static_cast(TMA::CacheHintSm90::EVICT_LAST)); + tma_store_arrive(); + } + } + // Note, the for_each() function is required here to ensure `warpgroup_idx` is of type Int. 
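A minimal, self-contained sketch of why that matters (cute::for_each and make_int_sequence are the only pieces assumed; pretend_tma_store_wait is a made-up stand-in for calls like tma_store_wait<N>() whose wait count must be a compile-time constant):

#include <cstdio>
#include <cute/tensor.hpp>

template <int WaitCount>
void pretend_tma_store_wait() { printf("wait until at most %d stores are in flight\n", WaitCount); }

int main() {
    using namespace cute;
    // Each iteration receives Int<0>, Int<1>, Int<2> as a distinct *type*, so its value
    // is a constant expression and can be used as a template argument.
    for_each(make_int_sequence<3>{}, [&](auto wg) {
        pretend_tma_store_wait<3 - 1 - decltype(wg)::value>();
        // A plain `for (int wg = 0; ...)` loop could not instantiate the template above.
    });
    return 0;
}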
+ for_each(make_int_sequence{}, [&] (auto warpgroup_idx) { + if (lane_predicate) { tma_store_wait(); } + cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmptyWG1) + warpgroup_idx /*id*/); // sdQ empty, ready to be written to + }); + if constexpr (Deterministic) { + Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head); + } + } + if constexpr (Is_local && Deterministic) { + constexpr int kBlockM = get<0>(TileShape_MNK{}); + int const m_block_global_max = cute::ceil_div(seqlen_info.seqlen_q, kBlockM); + #pragma unroll 2 + for (; m_block < m_block_global_max; ++m_block) { + Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head); + } + } + } + + CUTLASS_DEVICE void + mma_init() { + // We're not currently using this bc we're not using persistent scheduler + // // Tell producer (warp 0) that smem_k and smem_v are ready + // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + if constexpr (dQacc_use_TMA) { + if (warp_idx_in_warpgroup == 0) { + cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmptyWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/); // sdQ empty, ready to be written to + } + } + } + + template + CUTLASS_DEVICE bool + mma(Params const& params, + MainloopPipeline pipeline_q, + MainloopPipeline_dO pipeline_do, + PipelineState& smem_pipe_read, + PipelineState_dO& smem_pipe_read_do, + FrgTensordKV& tdKrdK, + FrgTensordKV& tdVrdV, + int thread_idx, + int &work_idx, + cute::tuple block_coord, + SharedStorage& shared_storage + ) { + static_assert(is_rmem::value, "dK and dV tensor must be rmem resident."); + + int n_block = get<0>(block_coord); + int bidb = get<2>(block_coord); + SeqlenInfo_t seqlen_info{ + bidb, get<0>(params.shape_Q), size<0>(params.shape_K), + params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k + }; + auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( + seqlen_info, n_block, bidb, params.window_size_left, + params.window_size_right, 0 /*sink_token_length*/); + // It's possible to have m_block_max <= m_block_min. 
Exit early + if constexpr (Is_causal || Is_local || Varlen) { + if (m_block_max <= m_block_min) { return false; } + } + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); + Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); + Tensor sQt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQt{}); + Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdOt{}); + Tensor sKt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutKt{}); + Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdS{}); + Tensor sP_pi = cute::as_position_independent_swizzle_tensor(sP); + Tensor sPt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdSt{}); + Tensor sPt_pi = cute::as_position_independent_swizzle_tensor(sPt); + Tensor sdS = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdS{}); + Tensor sdS_pi = cute::as_position_independent_swizzle_tensor(sdS); + Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdSt{}); + Tensor sdSt_pi = cute::as_position_independent_swizzle_tensor(sdSt); + Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dqacc.data()), SmemLayoutdQaccum{}); + Tensor sLSEMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSEMma{}); + Tensor sdPsumMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSEMma{}); + + static_assert(stride<0>(typename TiledMmaSdP::ALayout{}) == 0 and + stride<0>(typename TiledMmaSdP::BLayout{}) == 0 and + size<0>(typename TiledMmaSdP::ALayout{}) == cutlass::NumThreadsPerWarpGroup and + size<0>(typename TiledMmaSdP::BLayout{}) == cutlass::NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + constexpr int MmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(make_shape(Int{}), + make_stride(Int{})); + Layout warp_group_thread_layout_dq = make_layout(make_shape(Int{}), + make_stride(Int{})); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); + TiledMmaSdP tiled_mma_SdP; + using TiledMmadP = std::conditional_t; + TiledMmadP tiled_mma_dP; + TiledMmadKV tiled_mma_dKV; + TiledMmadQ tiled_mma_dQ; + + auto wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx)); + auto wg_mma_dP = tiled_mma_dP.get_slice(warp_group_thread_layout(warp_group_idx)); + auto thread_mma_SdP = tiled_mma_SdP.get_thread_slice(thread_idx); + auto wg_mma_dKV = tiled_mma_dKV.get_slice(warp_group_thread_layout(warp_group_idx)); + auto wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout_dq(warp_group_idx)); + + auto smem_tiled_copy_PdS = make_tiled_copy_C(SmemCopyAtomPdS{}, tiled_mma_SdP); + auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(thread_idx); + + R2STiledCopydQaccum r2s_tiled_copy_dQaccum; + auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(thread_idx); + Tensor tdQsdQaccum = 
r2s_thr_copy_dQaccum.partition_D(sdQ); + // if (thread_idx == 0) { print(sdQ); printf("\n"); print(tdQsdQaccum); printf("\n"); } + + // Allocate "fragments/descriptors" + // We have to use the templated mma_partition_fragment_AB instead of cute::conditional_return or lambda, + // because some partition_fragment_A/B don't compile. + // https://stackoverflow.com/questions/50051473/if-constexpr-in-c17-does-not-work-in-a-non-templated-function + Tensor tSrQ = mma_partition_fragment_AB(wg_mma_SdP, sQ); + Tensor tSrK = mma_partition_fragment_AB(wg_mma_SdP, sK); + Tensor tdPrdO = mma_partition_fragment_AB(wg_mma_SdP, sdO); + Tensor tdPrV = mma_partition_fragment_AB(wg_mma_dP, sV); + Tensor tdVrdO = mma_partition_fragment_AB(wg_mma_dKV, sdOt); + Tensor tdKrQ = mma_partition_fragment_AB(wg_mma_dKV, sQt); + Tensor tdQrdS = mma_partition_fragment_AB(wg_mma_dQ, sdS); + Tensor tdQrK = mma_partition_fragment_AB(wg_mma_dQ, sKt); + + Tensor tPsP = smem_thr_copy_PdS.partition_D(cute::conditional_return(sP_pi, sPt_pi)); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor tdSsdS = smem_thr_copy_PdS.partition_D(cute::conditional_return(sdS_pi, sdSt_pi)); // ((Atom,AtomNum),PIPE_M,PIPE_N) + // if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_PdS); print(sP_pi); printf("\n"); print(sPt_pi); printf("\n"); print(tPsP); printf("\n"); print(tdSsdS); printf("\n"); } + + // thread_mma_SdP.partition_C(sLSEMma) has shape ((2, 2, V), MMA_M, MMA_N, PIPE), we only take the col indices + // or row indices, depending on whether SdP_swapAB. + Tensor tLSEsLSE = cute::conditional_return( + group_modes<0, 2>(thread_mma_SdP.partition_C(sLSEMma)(make_coord(_0{}, _, _0{}), _, _0{}, _)), // (2, MMA_M, PIPE) + group_modes<0, 3>(thread_mma_SdP.partition_C(sLSEMma)(make_coord(_, _0{}, _), _0{}, _, _))); // (2, V, MMA_N, PIPE) + Tensor tLSEsdPsum = cute::conditional_return( + group_modes<0, 2>(thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_0{}, _, _0{}), _, _0{}, _)), + group_modes<0, 3>(thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_, _0{}, _), _0{}, _, _))); + // if (blockIdx.x == 0 && threadIdx.x == 128) { print(sLSEMma); printf("\n"); print(tLSEsLSE); printf("\n"); } + // If we want to split the stats among the 8 threads that share the same rows. + static constexpr int kStatsPerThread = cute::ceil_div(decltype(size(tLSEsLSE))::value, 8); + + auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + }; + + int bidh = get<1>(block_coord); + int const seqlen_q = seqlen_info.seqlen_q; + int const seqlen_k = seqlen_info.seqlen_k; + + // For the case where we do atomicAdd directly to gdQaccum instead of using TMA + bool const is_varlen = Varlen && params.cu_seqlens_q; + Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), + params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? 
bidb : 0); + Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(_)); // (M * K, _) + Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int{}); // (M * K / WG, WG, _) + // We can reuse r2s_thr_copy_dQaccum for this partitioning + Tensor tdQgdQaccum = r2s_thr_copy_dQaccum.partition_D(gdQaccum); + // if (blockIdx.x == 0 && threadIdx.x == 128) { print(mdQaccum); printf("\n"); print(gdQaccum_); printf("\n"); print(gdQaccum); printf("\n"); print(tdQgdQaccum); printf("\n"); } + + flash::Mask mask( + thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, + params.attention_chunk_divmod, params.qhead_per_khead_divmod + ); + + int m_block = m_block_min; + + clear(tdKrdK); + clear(tdVrdV); + // tiled_mma_dKV.accumulate_ = GMMA::ScaleOut::Zero; + + cutlass::ConsumerToken barrier_token = static_cast(shared_storage.pipelines.barrier_KV.try_wait(work_idx % 2)); + if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.pipelines.barrier_KV.wait(work_idx % 2); } + + if constexpr (Mma_dP_is_RS) { + using SmemCopyAtomV = Copy_Atom; + auto smem_tiled_copy_V = make_tiled_copy_A(SmemCopyAtomV{}, tiled_mma_dP); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(thread_idx); + Tensor tdPrV_copy_view = smem_thr_copy_V.retile_D(tdPrV); + Tensor tdPsV_copy_view = smem_thr_copy_V.partition_S(cute::as_position_independent_swizzle_tensor(sV)); + cute::copy(smem_tiled_copy_V, tdPsV_copy_view, tdPrV_copy_view); + } + + auto bwd_step = [&](int m_block, auto mask_fn) { + Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select(TileShape_MNK{})); + consumer_wait(pipeline_q, smem_pipe_read); + flash::gemm(tiled_mma_SdP, tSrQ(_, _, _, smem_pipe_read.index()), tSrK, tSrS); + Tensor tLSErLSE = cute::conditional_return(make_fragment_like(tLSEsLSE(_, _0{})), make_tensor(Int{})); + if constexpr (!ShuffleLSE) { + cute::copy(tLSEsLSE(_, smem_pipe_read.index()), tLSErLSE); + } else { + #pragma unroll + for (int i = 0; i < kStatsPerThread; ++i) { + // It's ok to read OOB, since we made sure sLSE is large enough and we won't use the OOB values + tLSErLSE(i) = tLSEsLSE((thread_idx % 32) / 4 + i * 8, smem_pipe_read.index()); + } + } + Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select(TileShape_MNK{})); + PipelineState_dO smem_pipe_read_do_cur = cute::conditional_return(smem_pipe_read, smem_pipe_read_do); + consumer_wait(pipeline_do, smem_pipe_read_do_cur); + flash::gemm(tiled_mma_dP, tdPrdO(_, _, _, smem_pipe_read_do_cur.index()), tdPrV, tdPrdP); + warpgroup_wait<1>(); + if constexpr (Has_softcap) { flash::apply_softcap(tSrS, params.softcap_val); } + + // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout())); + // dtanh needs to happen before masking, otherwise we get 1 - (-inf)^2 = NaN in the dtanh + auto dtanh = [&] { if constexpr (Has_softcap) return flash::calculate_dtanh(scores); else return nullptr; }(); + mask_fn(tSrS, m_block); + #pragma unroll + for (int mi = 0; mi < size<0>(scores); ++mi) { + float const lse_scaled = [&] { + if constexpr (!ShuffleLSE) return tLSErLSE(mi); + else return __shfl_sync(0xffffffff, tLSErLSE(mi / 8), (mi % 8) * 4 + (thread_idx % 4)); + }(); + #pragma unroll + for (int ni = 0; ni < size<1>(scores); ++ni) { + scores(mi, ni) = exp2f(scores(mi, ni) * params.softmax_scale_log2 - lse_scaled); + } + } + + Tensor tLSErdPsum = 
cute::conditional_return(make_fragment_like(tLSEsdPsum(_, _0{})), make_tensor(Int{})); + if constexpr (!ShuffledPsum) { + cute::copy(tLSEsdPsum(_, smem_pipe_read_do_cur.index()), tLSErdPsum); + } else { + #pragma unroll + for (int i = 0; i < kStatsPerThread; ++i) { + tLSErdPsum(i) = tLSEsdPsum((thread_idx % 32) / 4 + i * 8, smem_pipe_read_do_cur.index()); + } + } + + warpgroup_wait<0>(); + // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) + Tensor dS = make_tensor(tdPrdP.data(), scores.layout()); + #pragma unroll + for (int mi = 0; mi < size<0>(dS); ++mi) { + float const dP_sum_cur = [&] { + if constexpr (!ShuffledPsum) return tLSErdPsum(mi); + else return __shfl_sync(0xffffffff, tLSErdPsum(mi / 8), (mi % 8) * 4 + (thread_idx % 4)); + }(); + #pragma unroll + for (int ni = 0; ni < size<1>(dS); ++ni) { + dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum_cur); + if constexpr (Has_softcap) { dS(mi, ni) *= dtanh(mi, ni); } + } + } + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = make_tensor_like(tSrS); + flash::convert_type_out(tSrS, rP); + if constexpr (!Mma_dKV_is_RS) { + // Need to sync to make sure P has already been used in the previous iteration before writing new values + if constexpr (kStages_dS == 1) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(BwdNamedBarriers::PdS) /*id*/); + } + Tensor tPaP = smem_thr_copy_PdS.retile_S(rP); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_PdS, tPaP, tPsP(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index()))); + } + Tensor rdS = make_tensor_like(tdPrdP); + flash::convert_type_out(tdPrdP, rdS); + // If there's double buffering on dS, we don't need to sync here. + // Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ. + // But because both WGs have to sync at the end of the loop and double buffering, + // this race condition is not possible. + // This sync is to ensure (1) P is written in case of !Mma_dKV_is_RS and + // (2) dS is already read by the Mma in the previous iteration in case of Mma_dKV_is_RS. 
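The hazard being avoided here is the classic single-stage shared-memory reuse race. A stripped-down CUDA sketch (using __syncthreads() as a stand-in for the cross-warp-group NamedBarrier, with made-up kernel and buffer names) shows the pattern: before the lone stage is overwritten, the writer must know that every reader of the previous iteration has finished, whereas with a second stage the writer targets the buffer the readers are not using and the wait can be elided.

__global__ void single_stage_reuse(float const* in, float* out, int iters) {
    __shared__ float buf[256];                   // one stage, shared by all 256 threads
    for (int it = 0; it < iters; ++it) {
        __syncthreads();                         // all readers of iteration it-1 are done
        buf[threadIdx.x] = in[it * 256 + threadIdx.x];   // now safe to overwrite the stage
        __syncthreads();                         // stage is full, readers may proceed
        // Read a neighbour's element so the cross-thread dependency is real.
        out[it * 256 + threadIdx.x] = 2.f * buf[(threadIdx.x + 1) % 256];
    }
}

This sketch assumes a launch with exactly 256 threads per block, e.g. single_stage_reuse<<<1, 256>>>(in, out, iters).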
+ if constexpr (!Mma_dKV_is_RS || (kStages_dS == 1 && Mma_dKV_is_RS)) { + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(BwdNamedBarriers::PdS) /*id*/); + } + // For hdim 64, It's faster to write to smem_dS first before the dV gemm + Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index()))); + + if constexpr (!Slice_dQKV_Mma) { + // Most cases take this path, except for hdim256 where we want to slice to reduce register pressure + if constexpr (Mma_dKV_is_RS) { + Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs(tSrS.layout())); + flash::gemm(tiled_mma_dKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); + } else { + Tensor tdVrP = mma_partition_fragment_AB(wg_mma_dKV, sPt); + Tensor tdVrP_cur = tdVrP(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); + flash::gemm(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); + } + // SMEM fence to make sure sdS is written before it's read by WGMMA + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(BwdNamedBarriers::PdS) /*id*/); + Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{})); + Tensor tdQrdS_cur = tdQrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); + flash::gemm(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ); + pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dQ + + if constexpr (Mma_dKV_is_RS) { + Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); + flash::gemm(tiled_mma_dKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); + } else { + Tensor tdKrdS = mma_partition_fragment_AB(wg_mma_dKV, sdSt); + Tensor tdKrdS_cur = tdKrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); + flash::gemm(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); + } + if constexpr (dQacc_use_TMA) { + int const warp_group_idx = flash::canonical_warp_group_idx_nosync() - 1; + cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmptyWG1) + warp_group_idx /*id*/); // sdQ full, to be written to gmem + Tensor taccdQrdQ = r2s_thr_copy_dQaccum.retile_S(tdQrdQ); + cute::copy(r2s_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum); + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQFullWG1) + warp_group_idx /*id*/); // sdQ full, to be written to gmem + } else { + // We can reuse r2s_thr_copy_dQaccum for this partitioning + Tensor tdQrdQ_atomic = recast(r2s_thr_copy_dQaccum.retile_S(tdQrdQ)); + Tensor tdQgdQaccum_atomic = recast(tdQgdQaccum(_, _, _, m_block)); + static_assert(CUTE_STATIC_V(size(tdQrdQ_atomic)) == CUTE_STATIC_V(size(tdQgdQaccum_atomic))); + #pragma unroll + for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } + } + + } else { // Slice_dQKV_Mma + + static_assert(!(Slice_dQKV_Mma && Mma_dKV_is_RS)); + Tensor tdVrP = mma_partition_fragment_AB(wg_mma_dKV, sPt); + Tensor tdVrP_cur = tdVrP(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); + flash::gemm(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); + + 
cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(BwdNamedBarriers::PdS) /*id*/); + Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{})); + Tensor tdQrdS_cur = tdQrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); + flash::gemm(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ); + flash::gemm(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); + Tensor tdQrdQ_atomic = recast(r2s_thr_copy_dQaccum.retile_S(tdQrdQ)); + Tensor tdQgdQaccum_atomic = recast(tdQgdQaccum(_, _, _, m_block)); + #pragma unroll + for (int i = 0; i < size(tdQrdQ_atomic) / 2; ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } + + Tensor tdKrdS = mma_partition_fragment_AB(wg_mma_dKV, sdSt); + Tensor tdKrdS_cur = tdKrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); + flash::gemm(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); + pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dO + + flash::gemm(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ); + #pragma unroll + for (int i = size(tdQrdQ_atomic) / 2; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } + + flash::gemm(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); + } + + warpgroup_wait<0>(); + pipeline_q.consumer_release(smem_pipe_read); // release Q + ++smem_pipe_read; + if constexpr (!Q_dO_same_stages) { ++smem_pipe_read_do; } + }; + + // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64 + // this helps quite a bit to not have to do causal masking for most of the iterations. + if constexpr ((Is_causal || Is_local) && SeparateMaskingIterations) { + auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + int const m_block_masking_max = ((n_block + 1) * kBlockN - 1 + seqlen_q - seqlen_k - params.window_size_right) / kBlockM + 1; + CUTLASS_PRAGMA_NO_UNROLL + for (; m_block < std::min(m_block_max, m_block_masking_max); ++m_block) { + bwd_step(m_block, mask_fn); + } + } + + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + int const m_block_max_before_local_mask = !Is_local || !SeparateMaskingIterations + ? 
m_block_max + : std::min(m_block_max, (n_block * kBlockN + seqlen_q - seqlen_k + params.window_size_left) / kBlockM); + + auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; + CUTLASS_PRAGMA_NO_UNROLL + for (; m_block < m_block_max_before_local_mask; ++m_block) { + bwd_step(m_block, mask_fn); + } + + if constexpr (Is_local && SeparateMaskingIterations) { + auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; + CUTLASS_PRAGMA_NO_UNROLL + for (; m_block < m_block_max; ++m_block) { + bwd_step(m_block, mask_fn); + } + } + + // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tdVrdV); } + #pragma unroll + for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.softmax_scale; } + + if constexpr (Q_dO_same_stages) { smem_pipe_read_do = smem_pipe_read; } + ++work_idx; + return true; + } + +}; + +} // namespace flash diff --git a/flash-attn/mainloop_fwd_sm80.hpp b/flash-attn/mainloop_fwd_sm80.hpp new file mode 100644 index 0000000000000000000000000000000000000000..297927bfda165a688e7e770b1ab7917d438c1695 --- /dev/null +++ b/flash-attn/mainloop_fwd_sm80.hpp @@ -0,0 +1,855 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include + +#include "cute/tensor.hpp" + +#include "seqlen.h" +#include "block.h" +#include "mask.h" +#include "pack_gqa.h" +#include "paged_kv.h" +#include "rotary.h" +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +struct CollectiveMainloopFwdSm80 { + + static constexpr int kStages = Stages; + static_assert(kStages > 0, "kStages must be greater than 0"); + using TileShape_MNK = TileShape_MNK_; + using TileShape_MNK_PV = Shape(TileShape_MNK{})), Int, decltype(get<1>(TileShape_MNK{}))>; + using Element = Element_; + using ElementAccum = ElementAccum_; + using ArchTag = ArchTag_; + static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v;; + static constexpr bool Is_causal = Is_causal_; + static constexpr bool Is_local = Is_local_; + static constexpr bool Has_softcap = Has_softcap_; + static constexpr bool Varlen = Varlen_; + static constexpr bool PagedKV = PagedKV_; + static constexpr bool AppendKV = AppendKV_; + static constexpr bool PackGQA = PackGQA_; + static constexpr bool Split = Split_; + static constexpr bool Transpose_V = Is_FP8; + + static_assert(ArchTag::kMinComputeCapability >= 80); + + static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80; + + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + + using SeqlenInfo_t = flash::SeqlenInfoQKNewK; + using BlockMN_t = flash::BlockMN; + + using MMA_Atom_Arch = std::conditional_t< + ArchTag::kMinComputeCapability >= 80, + std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >, + MMA_Atom + >; + using TiledMma = TiledMMA< + MMA_Atom_Arch, + Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group + Tile, _16, _16>>; + + static constexpr int NumMmaThreads = size(TiledMma{}); + static constexpr int NumProducerThreads = NumMmaThreads; // For compatibility with TileScheduler + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + 
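To spell out the constraint just asserted (assuming 16-bit elements such as cutlass::half_t): one 128-bit load moves kGmemElemsPerLoad = 16 B / 2 B = 8 elements, so the head dimension must be a multiple of 8. The toy CUDA kernel below, with made-up names, shows the shape of such a fully vectorized copy, using float4 (also 128 bits) as one vector per thread per step.

__global__ void copy_rows_128bit(float4 const* __restrict__ src, float4* __restrict__ dst,
                                 int vecs_per_row) {
    int const row = blockIdx.x;                  // one block per row
    for (int v = threadIdx.x; v < vecs_per_row; v += blockDim.x) {
        // One float4 assignment is a single 128-bit (16-byte) global memory transaction.
        dst[row * vecs_per_row + v] = src[row * vecs_per_row + v];
    }
}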
// We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each + // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. + static constexpr int kBytePerRow = kHeadDim * sizeof(Element); + static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); + + static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1)); + static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4); + using SmemLayoutAtomQKV = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQKV{}, select<0, 2>(TileShape_MNK{}))); + + using SmemLayoutK = decltype(tile_to_shape( + SmemLayoutAtomQKV{}, + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + + using SmemLayoutV = decltype(tile_to_shape( + SmemLayoutAtomQKV{}, + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + using SmemLayoutVt = decltype( + composition(SmemLayoutV{}, + make_ordered_layout(make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}), + Step<_2, _1, _3>{}))); + + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using GmemCopyAtom = Copy_Atom, + AutoVectorizingCopyWithAssumedAlignment<128> + >, Element>; + + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; + static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(GmemCopyAtom{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per read + // So that we don't have to check if we overshot kBlockM when we load Q + static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0); + + // For AppendKV, We want each thread to have at least 2 loads in the K direction since in the case of + // non-interleaved rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), + // each thread will load twice from the same row. + static constexpr int kBytePerHalfRow = kHeadDim / 2 * sizeof(Element); + static constexpr int kBlockKGmemAppend = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element); + static constexpr int kGmemThreadsPerRowAppend = kBlockKGmemAppend / kGmemElemsPerLoad; + static_assert(NumMmaThreads % kGmemThreadsPerRowAppend == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRowAppend"); + // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where + // these threads share the same page table entry and share the work of computing pointers to paged K and paged V. 
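A small host-side sketch of the thread-to-(row, vector) mapping implied by GmemLayoutAtomAppend (the constants below are illustrative stand-ins, not the kernel's actual values): thread_idx / threads_per_row selects the row and thread_idx % threads_per_row selects the 128-bit vector within it, so as long as the threads-per-row count divides 32, all lanes touching a given row sit in one warp, which is what the assertion right after this relies on.

#include <cstdio>

int main() {
    int const num_threads     = 128;   // stand-in for NumMmaThreads
    int const threads_per_row = 8;     // stand-in for kGmemThreadsPerRowAppend
    for (int t = 0; t < num_threads; ++t) {
        int const row  = t / threads_per_row;    // which K/V row this thread helps load
        int const vec  = t % threads_per_row;    // which 128-bit vector inside that row
        int const warp = t / 32;
        if (t < 16) { printf("thread %2d -> row %2d, vec %d (warp %d)\n", t, row, vec, warp); }
    }
    return 0;
}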
+ static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRowAppend == 0, "kGmemThreadsPerRowAppend must divide NumThreadsPerWarp"); + using GmemLayoutAtomAppend = Layout, Int>, + Stride, _1>>; + // If AppendKV, we'll be loading Q for rotary, and we assume divisibility to avoid predication + static_assert(!AppendKV || kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtomAppend{})) == 0, "kBlockM must be a multiple of NumMmaThreads / kGmemThreadsPerRowAppend"); + using GmemTiledCopyAppendKV = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomAppend{}, + Layout>>{})); // Val layout, 8 or 16 vals per store + + using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) + using StrideQK = cute::Stride; + using StrideV = StrideQK; + // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits) + using ShapeQPacked = std::conditional_t, int32_t, int32_t, int32_t>>; + using StrideQPacked = std::conditional_t, _1, int64_t, int64_t>>; + using ShapePageTable = cute::Shape; // (batch, max_num_pages_per_seq) + using StridePageTable = cute::Stride; + using ShapeRotary = cute::Shape; // (seqlen_ro, rotary_dim // 2) + using StrideRotary = cute::Stride; + using StrideDescale = cute::Stride; + + static constexpr bool Share_QV_Smem = Q_in_regs; + + struct TensorStorageSharedQV : cute::aligned_struct<128> { + union { + cute::array_aligned> smem_v; + cute::array_aligned> smem_q; + }; + cute::array_aligned> smem_k; + }; + + struct TensorStorageSeparateQV : cute::aligned_struct<128> { + cute::array_aligned> smem_v; + cute::array_aligned> smem_k; + cute::array_aligned> smem_q; + }; + + using TensorStorage = std::conditional_t; + + // Host side kernel arguments + struct Arguments { + Element const* const ptr_Q; + ShapeQKV const shape_Q; + StrideQK const stride_Q; + Element* const ptr_K; // Not Element const* since we might append to KV cache in-place + ShapeQKV const shape_K; + StrideQK const stride_K; + Element* const ptr_V; + int32_t const headdim_v; + StrideV const stride_V; + Element const* const ptr_K_new; + ShapeQKV const shape_K_new; + StrideQK const stride_K_new; + Element const* const ptr_V_new; + StrideV const stride_V_new; + Element const* const ptr_Qv; + StrideQK const stride_Qv; + Element const* const ptr_rotary_cos; + ShapeRotary const shape_rotary; + StrideRotary const stride_rotary_cos; + Element const* const ptr_rotary_sin; + StrideRotary const stride_rotary_sin; + bool const is_rotary_interleaved; + int const* const ptr_pagetable; + ShapePageTable const shape_pagetable; + StridePageTable const stride_pagetable; + float const softmax_scale; + float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; + StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; + int const window_size_left = -1, window_size_right = -1, attention_chunk = 0; + float const softcap_val; + int const num_splits; + int const* const kv_batch_idx = nullptr; + int const* const cu_seqlens_q = nullptr; + int const* const cu_seqlens_k = nullptr; + int const* const cu_seqlens_k_new = nullptr; + int const* const seqused_q = nullptr; + int const* const seqused_k = nullptr; + int const* const leftpad_k = nullptr; + int const* const seqlens_rotary = nullptr; + }; + + // Device side kernel params + struct Params { + Element const* const ptr_Q; + ShapeQKV const shape_Q; + StrideQK const stride_Q; + ShapeQPacked const shape_Q_packed; + StrideQPacked const stride_Q_packed; + Element* const ptr_K; + ShapeQKV const shape_K; + StrideQK const stride_K; + Element* const ptr_V; + int32_t const headdim_v; + 
StrideV const stride_V; + Element const* const ptr_K_new; + ShapeQKV const shape_K_new; + StrideQK const stride_K_new; + Element const* const ptr_V_new; + StrideV const stride_V_new; + Element const* const ptr_rotary_cos; + ShapeRotary const shape_rotary; + StrideRotary const stride_rotary_cos; + Element const* const ptr_rotary_sin; + StrideRotary const stride_rotary_sin; + bool const is_rotary_interleaved; + int const* const ptr_pagetable; + ShapePageTable const shape_pagetable; + StridePageTable const stride_pagetable; + cutlass::FastDivmod page_size_divmod; + cutlass::FastDivmod qhead_per_khead_divmod; + float const softmax_scale_log2; + float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; + StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; + float const softcap_val; + int const window_size_left, window_size_right; + cutlass::FastDivmod attention_chunk_divmod; + int const num_splits; + int const* const kv_batch_idx = nullptr; + int const* const cu_seqlens_q = nullptr; + int const* const cu_seqlens_k = nullptr; + int const* const cu_seqlens_k_new = nullptr; + int const* const seqused_q = nullptr; + int const* const seqused_k = nullptr; + int const* const leftpad_k = nullptr; + int const* const seqlens_rotary = nullptr; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size) + int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K)); + auto const shape_Q_packed = cute::conditional_return( + args.shape_Q, + make_shape(make_shape(qhead_per_khead, get<0>(args.shape_Q)), get<1>(args.shape_Q), get<2>(args.shape_K), get<3>(args.shape_Q)) + ); + auto const stride_Q_packed = cute::conditional_return( + args.stride_Q, + make_stride(make_stride(get<2>(args.stride_Q), get<0>(args.stride_Q)), get<1>(args.stride_Q), get<2>(args.stride_Q) * qhead_per_khead, get<3>(args.stride_Q)) + ); + if (get<1>(args.shape_rotary) > 0) { + assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr); + } + assert(args.num_splits >= 1); + // Avoid dividing by zero + cutlass::FastDivmod attention_chunk_divmod(args.attention_chunk >= 1 ? args.attention_chunk : 1); + attention_chunk_divmod.divisor = args.attention_chunk; + // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. + // Right after this, we multiply by log2(e) before applying exp2. + // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val + // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) + // (assigning it to params.softmax_scale_log2). + return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed, + args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V, + args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new, + args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos, + args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved, + args.ptr_pagetable, args.shape_pagetable, args.stride_pagetable, + cutlass::FastDivmod(int(get<0>(args.shape_K))), + cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), + !Has_softcap ? 
float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), + args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale, + args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, + !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, + args.window_size_left, args.window_size_right, attention_chunk_divmod, + !Split ? 1 : args.num_splits, + args.kv_batch_idx, + args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, + args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary}; + } + + template + CUTLASS_DEVICE bool + mma(Params const& params, + FrgTensorO& tOrO, + Softmax& softmax, + int const thread_idx, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord, + SharedStorage& shared_storage + ) { + static_assert(is_rmem::value, "O tensor must be rmem resident."); + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + + // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda + int const m_block = get<0>(block_coord); + int const bidh = get<1>(block_coord); + int const bidb = get<2>(block_coord); + int const split_idx = get<3>(block_coord); + int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; + auto n_block_min_max = BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); + int const n_block_min = get<0>(n_block_min_max); + int const n_block_max = get<1>(n_block_min_max); + // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier + if constexpr (Is_causal || Is_local || Varlen || Split) { + if (n_block_max <= n_block_min) { return false; } + } + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); + Tensor sVt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{}); + + bool const is_varlen_q = Varlen && params.cu_seqlens_q; + bool const is_varlen_k = Varlen && params.cu_seqlens_k; + + int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb]; + Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor gQ = local_tile(mQ, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K + seqlen_info.offset_k * get<0>(params.stride_K)), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + Tensor gK = local_tile(mK, select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V + seqlen_info.offset_k * get<0>(params.stride_V)), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? 
bidb_kv : 0); + Tensor gV = local_tile(mV, select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + + GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(thread_idx); + auto gmem_thr0_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(_0{}); // For index calculation + + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_slice(thread_idx); + + // Allocate "fragments/descriptors" + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); + + // Copy Atom retiling + auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(thread_idx); + auto smem_tiled_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(thread_idx); + auto smem_tiled_copy_V = make_tiled_copy_B(SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(thread_idx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // Predicates + Tensor cKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); + Tensor t0KVcKV = gmem_thr0_copy_QKV.partition_S(cKV); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_K); } + + int const seqlen_q = seqlen_info.seqlen_q; + int const seqlen_k = seqlen_info.seqlen_k; + int n_block = n_block_max - 1; + + // Prologue: load Q, K, V + // If persistent, we don't need to wait for the previous work_idx to finish + // since we assume that all MMA threads sync in the epilogue before writing to smem_o. + // So any thread gets there, all threads must have finished the previous MMA and at least started + // writing to smem_o. + // If persistent, need to sync to make sure all threads have finished with smem_o before writing to smem_v + if constexpr (Share_QV_Smem) { __syncthreads(); } + if constexpr (!PackGQA) { + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor cQ = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); + Tensor t0QcQ = gmem_thr0_copy_QKV.partition_S(cQ); + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_Q); } + // Instead of passing in tQcQ, we pass in t0QcQ and subtract the offset from the limit + // (seqlen_q - m_block * kBlockM). This is because the entries of t0QcQ are known at compile time. 
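// A standalone sketch of the t0-relative predication trick described in the comment just
// above: instead of comparing each thread's own row coordinate (tQcQ) against the limit,
// the kernel compares thread 0's compile-time coordinate (t0QcQ) against a limit that has
// the per-thread row offset folded in once. The names row_of_thread0 / kRowsPerIter are
// hypothetical; the numbers are examples only.
#include <cassert>

constexpr int kRowsPerIter = 8;   // rows advanced per copy iteration (example value)

// Row that the reference thread (thread 0 of the copy tile) touches at iteration i:
// a compile-time expression, analogous to the entries of t0QcQ.
constexpr int row_of_thread0(int i) { return i * kRowsPerIter; }

int main() {
    int const seqlen_q = 100, m_block = 1, kBlockM = 64;
    int const base_limit = seqlen_q - m_block * kBlockM;          // 36 valid rows in this M tile
    for (int thread_row_offset = 0; thread_row_offset < kRowsPerIter; ++thread_row_offset) {
        // Fold the per-thread offset into the limit once...
        int const shifted_limit = base_limit - thread_row_offset;
        for (int i = 0; i < 8; ++i) {
            // ...so the per-element check compares a compile-time row against it.
            bool const direct  = row_of_thread0(i) + thread_row_offset < base_limit;
            bool const shifted = row_of_thread0(i) < shifted_limit;
            assert(direct == shifted);
        }
    }
    return 0;
}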
+ // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy( + gmem_tiled_copy_QKV, tQgQ, tQsQ, t0QcQ, tQpQ, seqlen_info.seqlen_q - m_block * kBlockM - get<0>(tQcQ(_0{}, _0{}, _0{})) + ); + } else { + using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element>; + PackGQAt::load_Q(mQ, sQ, params.qhead_per_khead_divmod, thread_idx, seqlen_q, m_block); + } + cute::cp_async_fence(); + + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/>; + PagedKVManager_t paged_kv_manager( + params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, + params.ptr_K, params.shape_K, params.stride_K, + params.ptr_V, params.headdim_v, params.stride_V, + params.page_size_divmod, + params.page_size_divmod /*blockN_per_page_size_divmod, not used since we don't use TMA*/, + bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k, + 0 /*bidb_kv_idx, not used since we don't use TMA for Sm8x*/ + ); + + auto load_K = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) { + static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; + if constexpr (!PagedKV) { + // Do we need bound check to make sure the row doesn't go above kBlockN + static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; + Tensor tKsK_cur = tKsK(_, _, _, smem_pipe_write); + // Instead of passing in tKVcKV, we pass in t0KVcKV and subtract the offset from the limit + // (seqlen_k - n_block * kBlockN). This is because the entries of t0KVcKV are known at compile time. + int const seqlenk_row_limit = -int(get<0>(tKVcKV(_0{}, _0{}, _0{}))) + (EvenN + ? seqlen_info.seqlen_k - n_block * kBlockN + : (!Seqlenk_mask ? kBlockN : std::min(seqlen_info.seqlen_k - n_block * kBlockN, kBlockN))); + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy( + gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK_cur, t0KVcKV, tKVpKV, seqlenk_row_limit); + } else { + paged_kv_manager.template load_page_table(n_block); + paged_kv_manager.template load_K(n_block, sK(_, _, smem_pipe_write)); + } + }; + + auto load_V = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) { + static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; + if constexpr (!PagedKV) { + // Do we need bound check to make sure the row doesn't go above kBlockN + static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; + Tensor tVsV_cur = tVsV(_, _, _, smem_pipe_write); + // We don't call flash::copy since it doesn't support bound checking + // to not overshot kBlockN when writing to smem. 
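// A standalone sketch of how I read the row-limit selection in the load_K / load_V lambdas
// above (the predicated copy body that this limit feeds continues directly below): when
// kBlockN is a multiple of the copy tile's row count (EvenN), only the sequence end needs
// guarding; otherwise the limit must also be capped at kBlockN so the copy never overshoots
// the tile. The per-thread offset subtraction from the previous sketch is omitted here;
// row_limit is a hypothetical helper and the numbers are examples only.
#include <algorithm>
#include <cassert>

int row_limit(bool even_n, bool seqlenk_mask, int seqlen_k, int n_block, int kBlockN) {
    int const remaining = seqlen_k - n_block * kBlockN;      // rows of K left in the sequence
    if (even_n) return remaining;                            // tile edge handled by the layout itself
    return !seqlenk_mask ? kBlockN                           // interior tile: only guard the tile edge
                         : std::min(remaining, kBlockN);     // last tile: guard both
}

int main() {
    assert(row_limit(/*even_n=*/true,  /*mask=*/true,  100, 1, 64) == 36);  // last block of a 100-row K
    assert(row_limit(/*even_n=*/false, /*mask=*/true,  100, 1, 64) == 36);
    assert(row_limit(/*even_n=*/false, /*mask=*/false, 100, 0, 64) == 64);  // interior block, no masking
    return 0;
}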
+ Tensor tVgV_cur = tVgV(_, _, _, n_block); + int const seqlenk_row_limit = seqlen_info.seqlen_k - n_block * kBlockN - get<0>(tKVcKV(_0{}, _0{}, _0{})); + #pragma unroll + for (int m = 0; m < size<1>(tVsV); ++m) { + // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to be checked + if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tKVcKV(_0{}, m, _0{})) < kBlockN) { + bool const predicate_n = !Seqlenk_mask || get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit; + #pragma unroll + for (int k = 0; k < size<2>(tVsV); ++k) { + cute::copy(gmem_tiled_copy_QKV.with(tKVpKV(k) && predicate_n), tVgV_cur(_, m, k), tVsV_cur(_, m, k)); + } + } + } + } else { + paged_kv_manager.template load_V(n_block, sV(_, _, smem_pipe_write)); + } + }; + + auto preprocess_Q = [&] { + if constexpr (!AppendKV) { + flash::cp_async_wait(); + } else { + if (get<1>(params.shape_rotary) > 0) { // Apply rotary to Q + using Rotary_t = Rotary; + Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, + params.ptr_rotary_sin, params.stride_rotary_sin, + params.is_rotary_interleaved, thread_idx, seqlen_q, + seqlen_info.seqlen_rotary); + int const qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor; + if (params.is_rotary_interleaved) { + auto [tRrCos, tRrSin] = cute::conditional_return( + rotary.template load_cos_sin(m_block), + rotary.template load_cos_sin_packgqa(m_block, params.qhead_per_khead_divmod) + ); + flash::cp_async_wait(); + __syncthreads(); + rotary.apply_Q_interleaved(sQ, tRrCos, tRrSin, m_block, qhead_per_khead); + } else { + auto [tRrCosCont, tRrSinCont] = cute::conditional_return( + rotary.template load_cos_sin(m_block), + rotary.template load_cos_sin_packgqa(m_block, params.qhead_per_khead_divmod) + ); + flash::cp_async_wait(); + __syncthreads(); + rotary.apply_Q_contiguous(sQ, tRrCosCont, tRrSinCont, m_block, qhead_per_khead); + } + } else { + flash::cp_async_wait(); + } + } + + if constexpr (Q_in_regs) { + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + Tensor tSsQ_copy_view = smem_thr_copy_Q.partition_S(sQ); + cute::copy(smem_tiled_copy_Q, tSsQ_copy_view, tSrQ_copy_view); + } + }; + + // If Share_QV_Smem, we load Q, then load 1 stage of K, then (optionally) rotate Q and + // read from smem_q to registers, then load V. + // If !Share_QV, Smem, we load Q, load all stages of K & V, then (optionally) rotate Q. + + if constexpr (Share_QV_Smem) { + load_K(n_block, 0, cute::true_type{} /*Seqlenk_mask*/); + cute::cp_async_fence(); + preprocess_Q(); + __syncthreads(); // Make sure all threads have read smem_q before loading V + } + + // For persistent, make sure all threads have finished reading smem_o + if constexpr (!Share_QV_Smem) { __syncthreads(); } + // Note, using the for_each() function here to ensure `stage` is of type Int. + for_each(make_int_sequence{}, [&] (auto stage) { + static constexpr bool Is_first_stage = CUTE_STATIC_V(stage) == 0; + static constexpr bool Is_last_stage = CUTE_STATIC_V(stage) == kStages - 1; + if constexpr (!Share_QV_Smem || !Is_first_stage) { + if (Is_first_stage || n_block - stage >= n_block_min) { + load_K(n_block - stage, stage, cute::bool_constant{} /*Seqlenk_mask*/); + } + // We want the fence outside the if statement to have a fixed number of cp.async commits. + // so that we can wait with the correct number of outstanding commits. 
+ cute::cp_async_fence(); + } + if constexpr (!Is_last_stage) { + if (Is_first_stage || n_block - stage >= n_block_min) { + load_V(n_block - stage, stage, cute::bool_constant{} /*Seqlenk_mask*/); + } + cute::cp_async_fence(); + } + }); + + if constexpr (!Share_QV_Smem) { preprocess_Q(); } + + flash::Mask mask( + thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, + params.attention_chunk_divmod, params.qhead_per_khead_divmod + ); + + float softcap_val = params.softcap_val; + if constexpr (Has_softcap && Is_FP8) { + float const q_descale = params.ptr_q_descale == nullptr ? 1.0f : params.ptr_q_descale[bidb * get<0>(params.stride_q_descale) + bidh_kv * get<1>(params.stride_q_descale)]; + float const k_descale = params.ptr_k_descale == nullptr ? 1.0f : params.ptr_k_descale[bidb * get<0>(params.stride_k_descale) + bidh_kv * get<1>(params.stride_k_descale)]; + softcap_val *= q_descale * k_descale; + } + // Softcapping needs to happen before masking since if we apply after masking, softcapping can turn + // -inf to e.g. -50.0, which can affect the attention softmax. + auto scoremod_premask_fn = [&](auto& tSrS) { + if constexpr (Has_softcap) { flash::apply_softcap(tSrS, softcap_val); } + }; + + int smem_pipe_read = 0, smem_pipe_write = kStages - 1; + + auto load_K_next = [&] { + if (n_block - kStages >= n_block_min) { + load_K(n_block - kStages, kStages > 1 ? smem_pipe_write : 0, cute::false_type{} /*Seqlenk_mask*/); + } + cute::cp_async_fence(); + }; + + auto sync = [&] { + flash::cp_async_wait(); + __syncthreads(); + }; + + clear(tOrO); + + auto fwd_step = [&](int const n_block, auto mask_fn, auto is_first_iter_type, auto check_inf_type) { + static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value; + static constexpr bool Check_inf = decltype(check_inf_type)::value; + Tensor tSrS = partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{})); + clear(tSrS); + sync(); + auto load_V_next = [&] { + if (n_block - kStages + 1 >= n_block_min) { + load_V(n_block - kStages + 1, kStages > 1 ? smem_pipe_write : 0, cute::bool_constant{} /*Seqlenk_mask*/); + } + cute::cp_async_fence(); + }; + Tensor tSrQ_cur = cute::conditional_return(tSrQ, thr_mma.partition_fragment_A(sQ)); + Tensor tSrK = thr_mma.partition_fragment_B(sK(_, _, _0{})); + flash::gemm_sm80( + tSrS, tSrQ_cur, tSrK, tSsQ, tSsK(_, _, _, kStages > 1 ? smem_pipe_read : 0), + tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K, load_V_next + ); + smem_pipe_write = smem_pipe_write < kStages - 1 ? smem_pipe_write + 1 : 0; + scoremod_premask_fn(tSrS); + // Faster to load_K before gemm if we only have 1 stage + if constexpr (kStages == 1) { sync(); load_K_next(); } + mask_fn(tSrS, n_block); + Tensor scores_scale = softmax.template max_get_scale(tSrS); + softmax.template online_softmax(tSrS); + if constexpr (Is_FP8) { flash::permute_Cregs_fp8(tSrS); } + Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); + Tensor tOrP = make_tensor_like(tOrP_acc); + convert_type_out(tOrP_acc, tOrP); + if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); } + if constexpr (kStages > 1) { sync(); } + Tensor tOrV = thr_mma.partition_fragment_B(sVt(_, _, _0{})); + flash::gemm_rs_sm80(tOrO, tOrP, tOrV, tOsVt(_, _, _, kStages > 1 ? smem_pipe_read : 0), tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + if constexpr (kStages > 1) { load_K_next(); } + smem_pipe_read = smem_pipe_read < kStages - 1 ? 
smem_pipe_read + 1 : 0; + }; + + auto first_iter_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; + fwd_step(n_block, first_iter_mask_fn, cute::true_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/); + --n_block; + if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking + auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; + int const n_block_min_causal_local_mask = BlockMN_t::get_n_block_min_causal_local_mask( + seqlen_info, m_block, n_block_min, params.window_size_right, + params.attention_chunk_divmod, params.qhead_per_khead_divmod); + #pragma unroll 1 + for (; n_block >= n_block_min_causal_local_mask; --n_block) { + fwd_step(n_block, mask_fn, cute::false_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/); + } + } + int const n_block_min_before_local_mask = BlockMN_t::get_n_block_min_before_local_mask( + seqlen_info, m_block, n_block_min, params.window_size_left, + params.attention_chunk_divmod, params.qhead_per_khead_divmod); + auto no_mask_fn = [](auto& tSrS, int n_block) { }; + #pragma unroll 1 + for (; n_block >= n_block_min_before_local_mask; --n_block) { + fwd_step(n_block, no_mask_fn, cute::false_type{} /*is_first_iter*/, cute::false_type{} /*check_inf*/); + } + // Separate masking iterations on the left for local attention + if constexpr (Is_local) { + auto local_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; + #pragma unroll 1 + for (; n_block >= n_block_min; --n_block) { + fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant{} /*check_inf*/); + } + } + float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; + Tensor scores_scale = softmax.finalize(v_descale); + softmax.rescale_o(tOrO, scores_scale); + if constexpr (Is_FP8) { flash::permute_output_fp8(tOrO); } + return true; + } + + template + CUTLASS_DEVICE bool + store_kv_new(Params const& params, + int const thread_idx, + SharedStorage &shared_storage, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord + ) { + auto [m_block, bidh, bidb, split_idx] = block_coord; + auto n_block_new_min_max = BlockMN_t::get_n_block_k_new_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); + int const n_block_new_min = get<0>(n_block_new_min_max); + int const n_block_new_max = get<1>(n_block_new_min_max); + if (n_block_new_max <= n_block_new_min) { return false; } + + Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); + + int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; + int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb]; + + bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new; + Tensor mKnew = make_tensor(make_gmem_ptr(params.ptr_K_new), params.shape_K_new, params.stride_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); + Tensor mVnew = make_tensor(make_gmem_ptr(params.ptr_V_new), params.shape_K_new, params.stride_V_new)(_, _, bidh_kv, !is_varlen_k_new ? 
bidb : 0); + + bool const is_varlen_k = Varlen && params.cu_seqlens_k; + Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + + Tensor gKnew = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gVnew = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mVnew), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + int const offset_k = seqlen_info.offset_k + seqlen_info.seqlen_k_og; + Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + int const seqlen_k_new = seqlen_info.seqlen_k_new; + using Rotary_t = Rotary; + Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, + params.ptr_rotary_sin, params.stride_rotary_sin, + params.is_rotary_interleaved, thread_idx, seqlen_k_new, + seqlen_info.seqlen_rotary); + + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; + PagedKVManager_t paged_kv_manager( + params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, + params.ptr_K, params.shape_K, params.stride_K, + params.ptr_V, params.headdim_v, params.stride_V, + params.page_size_divmod, + params.page_size_divmod /*blockN_per_page_size_divmod, not used since we don't use TMA*/, + bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k, + // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position + 0 /*bidb_kv_idx, not used since we don't use TMA for Sm8x*/ + ); + + static_assert(std::is_same_v); + static_assert(!PagedKV || std::is_same_v); + GmemTiledCopyQKV gmem_tiled_copy_kv_g2s; + auto gmem_thr_copy_kv_g2s = gmem_tiled_copy_kv_g2s.get_thread_slice(thread_idx); + auto gmem_thr0_copy_kv_g2s = gmem_tiled_copy_kv_g2s.get_thread_slice(_0{}); // Only for index calculation + GmemTiledCopyAppendKV gmem_tiled_copy_kv_s2g; + auto gmem_thr_copy_kv_s2g = gmem_tiled_copy_kv_s2g.get_thread_slice(thread_idx); + auto gmem_thr0_copy_kv_s2g = gmem_tiled_copy_kv_s2g.get_thread_slice(_0{}); // Only for index calculation + Tensor tKgKnew = gmem_thr_copy_kv_g2s.partition_S(gKnew); + Tensor tKsKg2s = gmem_thr_copy_kv_g2s.partition_S(sK); + Tensor tKsKs2g = gmem_thr_copy_kv_s2g.partition_S(sK); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tKgK = gmem_thr_copy_kv_s2g.partition_D(gK); + Tensor tVgVnew = gmem_thr_copy_kv_g2s.partition_S(gVnew); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tVsVg2s = gmem_thr_copy_kv_g2s.partition_S(sV); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tVsVs2g = gmem_thr_copy_kv_s2g.partition_S(sV); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tVgV = gmem_thr_copy_kv_s2g.partition_D(gV); + Tensor cK = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tKcKg2s = gmem_thr_copy_kv_g2s.partition_D(cK); + Tensor t0KcKg2s = gmem_thr0_copy_kv_g2s.partition_D(cK); + Tensor tKpKg2s = 
make_tensor(make_shape(size<2>(tKsKg2s))); + Tensor tKcKs2g = gmem_thr_copy_kv_s2g.partition_D(cK); + Tensor t0KcKs2g = gmem_thr0_copy_kv_s2g.partition_D(cK); + Tensor tKpKs2g = make_tensor(make_shape(size<2>(tKsKs2g))); + #pragma unroll + for (int k = 0; k < size(tKpKg2s); ++k) { tKpKg2s(k) = get<1>(tKcKg2s(_0{}, _0{}, k)) < get<1>(params.shape_K); } + #pragma unroll + for (int k = 0; k < size(tKpKs2g); ++k) { tKpKs2g(k) = get<1>(tKcKs2g(_0{}, _0{}, k)) < get<1>(params.shape_K); } + + auto load_K_new = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) { + static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; + static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; + Tensor tKsK_cur = tKsKg2s(_, _, _, smem_pipe_write); + int const seqlenk_row_limit = -int(get<0>(tKcKg2s(_0{}, _0{}, _0{}))) + (EvenN + ? seqlen_k_new - n_block * kBlockN + : (!Seqlenk_mask ? kBlockN : std::min(seqlen_k_new - n_block * kBlockN, kBlockN))); + // We don't need to clear the sK smem tiles since we won't write them out + flash::copy( + gmem_tiled_copy_kv_g2s, tKgKnew(_, _, _, n_block), tKsK_cur, t0KcKg2s, tKpKg2s, seqlenk_row_limit); + }; + + auto load_V_new = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) { + static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; + static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; + Tensor tVsV_cur = tVsVg2s(_, _, _, smem_pipe_write); + int const seqlenk_row_limit = -int(get<0>(tKcKg2s(_0{}, _0{}, _0{}))) + (EvenN + ? seqlen_k_new - n_block * kBlockN + : (!Seqlenk_mask ? kBlockN : std::min(seqlen_k_new - n_block * kBlockN, kBlockN))); + // We don't need to clear the sV smem tiles since we won't write them out + flash::copy( + gmem_tiled_copy_kv_g2s, tVgVnew(_, _, _, n_block), tVsV_cur, t0KcKg2s, tKpKg2s, seqlenk_row_limit); + }; + + auto store_K = [&] (int const n_block, int const smem_pipe_read) { + int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN); + if (get<1>(params.shape_rotary) <= 0) { + Tensor tKsK_cur = tKsKs2g(_, _, _, smem_pipe_read); + if constexpr (!PagedKV) { + Tensor tKgK_cur = tKgK(_, _, _, n_block); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_kv_s2g, tKsK_cur, tKgK_cur, tKcKs2g, tKpKs2g, std::min(seqlen_k_new - n_block * kBlockN, kBlockN) + ); + } else { + paged_kv_manager.store_K(n_block, tKsK_cur); + } + } else { + Tensor gK_cur = gK(_, _, n_block); + auto tPrKPtr = cute::conditional_return(paged_kv_manager.compute_K_ptr(), nullptr); + if (params.is_rotary_interleaved) { + auto [tRrCos, tRrSin] = rotary.template load_cos_sin(n_block); + rotary.template apply_K_interleaved(sK(_, _, smem_pipe_read), gK_cur, tKpKs2g, tRrCos, tRrSin, tPrKPtr, n_block); + } else { + auto [tRrCosCont, tRrSinCont] = rotary.template load_cos_sin(n_block); + rotary.template apply_K_contiguous(sK(_, _, smem_pipe_read), gK_cur, tKpKs2g, tRrCosCont, tRrSinCont, tPrKPtr, n_block, get<1>(params.shape_K)); + } + } + }; + + auto store_V = [&] (int const n_block, int const smem_pipe_read) { + int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN); + Tensor tVsV_cur = tVsVs2g(_, _, _, smem_pipe_read); + if constexpr (!PagedKV) { + Tensor tVgV_cur = tVgV(_, _, _, n_block); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_kv_s2g, tVsV_cur, tVgV_cur, 
tKcKs2g, tKpKs2g, n_limit); + } else { + paged_kv_manager.store_V(n_block, tVsV_cur); + } + }; + + int n_block = n_block_new_max - 1; + // Note, using the for_each() function here to ensure `stage` is of type Int. + for_each(make_int_sequence{}, [&] (auto stage) { + static constexpr bool Is_first_stage = CUTE_STATIC_V(stage) == 0; + static constexpr bool Is_last_stage = CUTE_STATIC_V(stage) == kStages - 1; + if (Is_first_stage || n_block - stage >= n_block_new_min) { + load_K_new(n_block - stage, stage, cute::bool_constant{} /*Seqlenk_mask*/); + } + cute::cp_async_fence(); + // If persistent, need to sync to make sure all threads have finished with smem_o before writing to smem_v + if constexpr (Is_first_stage) { __syncthreads(); } + if constexpr (!Is_last_stage) { + if (Is_first_stage || n_block - stage >= n_block_new_min) { + load_V_new(n_block - stage, stage, cute::bool_constant{} /*Seqlenk_mask*/); + } + cute::cp_async_fence(); + } + }); + + int smem_pipe_read = 0, smem_pipe_write = kStages - 1; + #pragma unroll 1 + for (; n_block >= n_block_new_min; --n_block) { + if constexpr (PagedKV) { paged_kv_manager.template load_page_table(n_block); } + flash::cp_async_wait(); + __syncthreads(); + store_K(n_block, kStages > 1 ? smem_pipe_read : 0); + if (n_block - kStages + 1 >= n_block_new_min) { + load_V_new(n_block - kStages + 1, kStages > 1 ? smem_pipe_write : 0, cute::bool_constant{} /*Seqlenk_mask*/); + } + cute::cp_async_fence(); + smem_pipe_write = smem_pipe_write < kStages - 1 ? smem_pipe_write + 1 : 0; + flash::cp_async_wait(); + __syncthreads(); + store_V(n_block, kStages > 1 ? smem_pipe_read : 0); + smem_pipe_read = smem_pipe_read < kStages - 1 ? smem_pipe_read + 1 : 0; + if (n_block - kStages >= n_block_new_min) { + load_K_new(n_block - kStages, kStages > 1 ? smem_pipe_write : 0, cute::false_type{} /*Seqlenk_mask*/); + } + cute::cp_async_fence(); + } + + return true; + + } + +}; + +} // namespace flash diff --git a/flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp b/flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp new file mode 100644 index 0000000000000000000000000000000000000000..536ff855fd4f18deb8289e7e1ffd08c279e7b887 --- /dev/null +++ b/flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -0,0 +1,1717 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include "cutlass/pipeline/pipeline.hpp" + +#include "cute/tensor.hpp" + +#include "cutlass/gemm/collective/builders/sm90_common.inl" + +#include "named_barrier.hpp" +#include "seqlen.h" +#include "block.h" +#include "mask.h" +#include "pack_gqa.h" +#include "paged_kv.h" +#include "rotary.h" +#include "utils.h" +#include "sm90_pipeline_no_cluster.hpp" + +namespace flash { + +using namespace cute; + +template +struct CollectiveMainloopFwdSm90 { + + static constexpr int kStages = Stages; + using ClusterShape = ClusterShape_; + using TileShape_MNK = TileShape_MNK_; + using TileShape_MNK_PV = Shape(TileShape_MNK{})), Int, decltype(get<1>(TileShape_MNK{}))>; + using TileShape_MNK_QV = Shape(TileShape_MNK{})), decltype(get<1>(TileShape_MNK{})), Int>; + using Element = Element_; + using ElementAccum = ElementAccum_; + using ArchTag = ArchTag_; + static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v;; + static constexpr bool Is_causal = Is_causal_; + static constexpr bool Is_local = Is_local_; + static constexpr bool Has_softcap = Has_softcap_; + static constexpr bool Varlen = Varlen_; + static constexpr bool PagedKVNonTMA = PagedKVNonTMA_; + static constexpr bool AppendKV = AppendKV_; + static constexpr bool HasQv = HasQv_; + static constexpr bool PackGQA = PackGQA_; + static constexpr bool Split = Split_; + static constexpr bool V_colmajor = V_colmajor_; + static constexpr bool Transpose_V = Is_FP8 && !V_colmajor; + static constexpr bool Use_TMA_Q = !PackGQA; + static constexpr bool Use_TMA_KV = !PagedKVNonTMA; + static_assert(Use_TMA_KV || CUTE_STATIC_V(size(ClusterShape{})) == 1, "If not using TMA for KV, ClusterShape must be 1"); + static_assert(Use_TMA_KV || !V_colmajor, "If not using TMA for KV, V_colmajor is not supported"); + static constexpr bool SameHeadDim = get<2>(TileShape_MNK{}) == kHeadDimV; + static constexpr bool LargeHeadDimV = kHeadDimV > 256; + + static_assert(ArchTag::kMinComputeCapability >= 90); + + static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 && !V_colmajor ? GMMA::Major::MN : GMMA::Major::K; + static constexpr cute::GMMA::Major TmaMajorV = !V_colmajor ? GMMA::Major::MN : GMMA::Major::K; + + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + + using SeqlenInfo_t = flash::SeqlenInfoQKNewK; + using BlockMN_t = flash::BlockMN; + + static_assert(!LargeHeadDimV || kHeadDimV % 256 == 0); + static_assert(!LargeHeadDimV || kBlockM <= 64, "kBlockM must be 64 or less for large Headdim_V"); + static_assert(!LargeHeadDimV || !MmaPV_is_RS, "MmaPV must be SS for large Headdim_V"); + + // Register bandwidth is actually a bottleneck so we don't want Q to be in registers. + // Leaving this option here for reference. + static constexpr bool MmaQK_is_RS = false; + // We can have MmaPV with P in smem in rmem to reduce register pressure at the cost of more smem. 
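// A standalone sketch of how a few of the derived compile-time switches above relate to
// each other. FwdTraitsSketch is a hypothetical stand-in (the real flags live directly in
// CollectiveMainloopFwdSm90); std::uint8_t / std::uint16_t stand in for an FP8 and a
// 16-bit element type by byte size only.
#include <cstdint>

template <class Element, bool V_colmajor, bool PackGQA, bool PagedKVNonTMA>
struct FwdTraitsSketch {
    static constexpr bool Is_FP8      = sizeof(Element) == 1;
    // FP8 with a row-major V needs an in-smem transpose of V before the P @ V GEMM.
    static constexpr bool Transpose_V = Is_FP8 && !V_colmajor;
    // PackGQA loads Q with cp.async, and the non-TMA paged-KV path loads K/V with cp.async.
    static constexpr bool Use_TMA_Q   = !PackGQA;
    static constexpr bool Use_TMA_KV  = !PagedKVNonTMA;
};

using Fp8RowMajorV = FwdTraitsSketch<std::uint8_t,  /*V_colmajor=*/false, /*PackGQA=*/false, /*PagedKVNonTMA=*/false>;
static_assert(Fp8RowMajorV::Transpose_V && Fp8RowMajorV::Use_TMA_Q && Fp8RowMajorV::Use_TMA_KV, "");

using Bf16PackedPaged = FwdTraitsSketch<std::uint16_t, /*V_colmajor=*/false, /*PackGQA=*/true, /*PagedKVNonTMA=*/true>;
static_assert(!Bf16PackedPaged::Transpose_V && !Bf16PackedPaged::Use_TMA_Q && !Bf16PackedPaged::Use_TMA_KV, "");

int main() { return 0; }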
+ static_assert(!(!MmaPV_is_RS && Is_FP8), "MmaPV must be RS if FP8"); + static_assert(!(!MmaPV_is_RS && Transpose_V), "MmaPV must be RS if Transpose_V"); + + // Slightly faster in this case to have WG1 use RS instead of SS to avoid waiting for the P smem write + static constexpr bool MmaPV_use_RS_WG1 = !MmaPV_is_RS && kHeadDim == 64 && kHeadDimV == 512; + + using AtomLayoutQK = Layout, _1, _1>>; + using TiledMmaQK = decltype(cute::make_tiled_mma( + std::conditional_t< + !MmaQK_is_RS, + decltype(cute::GMMA::ss_op_selector()), + decltype(cute::GMMA::rs_op_selector()) + >{}, + AtomLayoutQK{})); + using AtomLayoutPV = std::conditional_t< + !LargeHeadDimV, + AtomLayoutQK, + Layout, _1>> + >; + using TiledMmaPV = decltype(cute::make_tiled_mma( + std::conditional_t< + !MmaPV_is_RS, + decltype(cute::GMMA::ss_op_selector()), + decltype(cute::GMMA::rs_op_selector()) + >{}, + AtomLayoutPV{})); + using TiledMmaQV = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutQK{})); + // For hdim64,512, WG1 can use RS but WG2 must use SS + using TiledMmaPV_RS = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(), + AtomLayoutPV{})); + + static constexpr int NumMmaThreadsQK = size(TiledMmaQK{}); + static constexpr int NumMmaThreads = size(TiledMmaPV{}); + static constexpr int NumProducerThreads = !Transpose_V && Use_TMA_KV && Use_TMA_Q ? cutlass::NumThreadsPerWarp : cutlass::NumThreadsPerWarpGroup; + static_assert(NumMmaThreadsQK % cutlass::NumThreadsPerWarpGroup == 0); + static_assert(NumMmaThreads % cutlass::NumThreadsPerWarpGroup == 0); + static constexpr int NumMmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup; + static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); + + using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); + + using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutK = decltype(tile_to_shape( + SmemLayoutAtomK{}, + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + + using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::ss_smem_selector, decltype(cute::get<2>(TileShape_MNK_PV{}))>()); + using SmemLayoutVt = decltype(tile_to_shape( + SmemLayoutAtomVt{}, + make_shape(Int{}, shape<2>(TileShape_MNK_PV{}), Int{}), + std::conditional_t, cute::Step<_2, _1, _3>>{})); + + using SmemLayoutAtomVtMma = decltype(cutlass::gemm::collective::detail::ss_smem_selector, decltype(cute::get<2>(TileShape_MNK_PV{}))>()); + using SmemLayoutVtMma = decltype(tile_to_shape( + SmemLayoutAtomVtMma{}, + make_shape(Int{}, shape<2>(TileShape_MNK_PV{}), Int{}), + std::conditional_t, cute::Step<_2, _1, _3>>{})); + + using SmemLayoutAtomQv = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_QV{})), decltype(cute::get<2>(TileShape_MNK_QV{}))>()); + using SmemLayoutQv = decltype(tile_to_shape(SmemLayoutAtomQv{}, select<0, 2>(TileShape_MNK_QV{}))); + using SmemLayoutAtomVMmaQV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_QV{})), decltype(cute::get<2>(TileShape_MNK_QV{}))>()); + using SmemLayoutVMmaQV = decltype(tile_to_shape( + SmemLayoutAtomVMmaQV{}, + make_shape(shape<1>(TileShape_MNK_QV{}), Int{}, Int{}))); + 
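// A standalone sketch of the thread-budget arithmetic implied by the tiled-MMA and
// producer-thread definitions above, assuming (as I read AtomLayoutQK) one 128-thread warp
// group of MMAs per 64 rows of the M tile. num_mma_threads / num_producer_threads are
// hypothetical helpers; the real constants are NumMmaThreads / NumProducerThreads.
constexpr int kThreadsPerWarp      = 32;
constexpr int kThreadsPerWarpGroup = 4 * kThreadsPerWarp;   // 128 threads

// One warp group of MMAs per 64 rows of the M tile.
constexpr int num_mma_threads(int kBlockM) { return (kBlockM / 64) * kThreadsPerWarpGroup; }

// A single warp suffices as producer when Q and K/V all arrive via TMA and no in-smem
// transpose of V is needed; otherwise a full warp group of cp.async producers is used.
constexpr int num_producer_threads(bool tma_only) { return tma_only ? kThreadsPerWarp : kThreadsPerWarpGroup; }

static_assert(num_mma_threads(128) == 256, "two MMA warp groups");
static_assert(num_mma_threads(192) == 384, "three MMA warp groups");
static_assert(num_producer_threads(true) == 32 && num_producer_threads(false) == 128, "");

int main() { return 0; }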
static_assert(CUTE_STATIC_V(size(SmemLayoutVMmaQV{})) == size(SmemLayoutVtMma{})); + + // Only used if we're using cp.async to load V + using SmemLayoutAtomVCpAsync = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), Int>()); + using SmemLayoutVCpAsync = decltype(tile_to_shape( + SmemLayoutAtomVCpAsync{}, + make_shape(shape<1>(TileShape_MNK{}), Int{}, Int{}))); + + using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{}))); + + // Only for LargeHeadDimV where WG0 sends WG1 the scales + using SmemLayoutScale = cute::Layout, Int>>; + + using SmemCopyAtomP = Copy_Atom; + + // Use LDSM.T and STSM to transpose V in the case of FP8 and V being row-major. + // For FP16/BF16 we don't do any transposing. + static_assert(!Transpose_V || (kHeadDimV % 32 == 0 && kBlockN % 32 == 0)); + static constexpr bool kHeadDimV_multiple_64 = kHeadDimV % 64 == 0; + // Either kHeadDimV is a multiple of 64 (in which case we use a block size of 64 x 32 for the transpose), + // or we need kBlockN to be a multiple of 64 (in which case we use a block size of 32 x 64 for the transpose). + static_assert(!Transpose_V || (kHeadDimV_multiple_64 || kBlockN % 64 == 0)); + using LDSM_thread_shape = std::conditional_t, Shape<_16, _4, _1, _2>>; + using LDSM_thread_stride = std::conditional_t, Stride<_4, _1, _0, _64>>; + using LDSM_value_shape = Shape<_2, _2, _1, _4>; + using LDSM_value_stride = Stride<_1, _2, _16, _4>; + using LDSM_divide_shape = std::conditional_t, Shape<_32, _8>>; + using S2RTiledCopyVt = decltype(make_tiled_copy( + Copy_Atom{}, Layout{}, + Layout{})); + + using STSM_thread_shape = std::conditional_t, Shape<_8, _4, _2, _2>>; + using STSM_thread_stride = std::conditional_t, Stride<_4, _1, _32, _64>>; + using STSM_value_shape = Shape<_1, _4, _2, _2>; + using STSM_value_stride = Stride<_0, _1, _4, _8>; + using STSM_divide_shape = Shape<_8, _16>; + // These will not permute the columns of V (the kHeadDimV dimension) but incur bank conflicts + // so a little slower (e.g. 1150 TFLOPS for hdim 256 instead of 1200 TFLOPS). + // Instead we will permute the cols of V, and un-permute the cols of O in the epilogue. + // using STSM_value_shape = Shape<_2, _4, _1, _2>; + // using STSM_value_stride = Stride<_4, _1, _0, _8>; + // using STSM_divide_shape = Shape<_16, _16>; + using R2STiledCopyV = decltype(make_tiled_copy( + Copy_Atom{}, Layout{}, + Layout{})); + + using GmemTiledCopyQ = cute::SM90_TMA_LOAD; + using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{}))); + + // We use CpAsync for K and V if PagedKVNonTMA and AppendKV, since TMA doesn't work there + static constexpr int kHeadDimGCD = cute::gcd(kHeadDim, kHeadDimV); + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDimGCD % kGmemElemsPerLoad == 0, "Headdim and HeaddimV must be a multiple of kGmemElemsPerLoad"); + // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each + // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. 
+ // We want each thread to have at least 2 loads in the K direction since in the case of non-interleaved + // rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), each thread will + // load twice from the same row. + static constexpr int kBytePerHalfRow = kHeadDimGCD / 2 * sizeof(Element); + static constexpr int kBlockKGmem = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element); + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; + static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); + // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKVNonTMA where + // these threads share the same page table entry and share the work of computing pointers to paged K and paged V. + static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + // If AppendKV, we'll be loading Q for rotary, and we assume divisibility to avoid predication + static_assert(!AppendKV || kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumMmaThreads / kGmemThreadsPerRow"); + using GmemTiledCopyAppendKV = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per store + + using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) + using StrideQK = cute::Stride; + using StrideV = std::conditional_t>; + // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits) + using ShapeQPacked = std::conditional_t, int32_t, int32_t, int32_t>>; + using StrideQPacked = std::conditional_t, _1, int64_t, int64_t>>; + using ShapePageTable = cute::Shape; // (batch, max_num_pages_per_seq) + using StridePageTable = cute::Stride; + using ShapeRotary = cute::Shape; // (seqlen_ro, rotary_dim // 2) + using StrideRotary = cute::Stride; + using StrideDescale = cute::Stride; + + using TMA_Q = decltype(make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), + SmemLayoutQ{}, + TileShape_MNK{}, + ClusterShape{})); + + using TMA_K = decltype(make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), + take<0, 2>(SmemLayoutK{}), + TileShape_MNK{}, + ClusterShape{})); // mcast along M mode for this N load, if any + + using TMA_V = decltype(make_tma_copy( + GmemTiledCopyKV{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, select<1, 0, 2, 3>(StrideV{})), + take<0, 2>(SmemLayoutVt{}), + select<1, 2>(TileShape_MNK_PV{}), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + + using TMA_Qv_ = decltype(make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), + SmemLayoutQv{}, + TileShape_MNK_QV{}, + ClusterShape{})); + using TMA_Qv = std::conditional_t; + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesK = static_cast(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesV = static_cast(size(take<0, 2>(SmemLayoutVt{})) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t 
TmaTransactionBytesQv = static_cast(size(SmemLayoutQv{}) * cutlass::sizeof_bits_v / 8); + + using PipelineTmaAsync = std::conditional_t, typename cutlass::PipelineTmaAsync>; + using MainloopPipelineK = std::conditional_t>; + using MainloopPipelineV = std::conditional_t>; + using MainloopPipelineVt = std::conditional_t>; + // We always use TMA for K_new and V_new + using MainloopPipelineKVNew = PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + // If PackGQA, we use cp.async (instead of TMA) to load Q, so we want smem_q to be aligned + // and have sQ being position_independent_swizzle_tensor. + // If !Use_TMA_KV, we use cp.async (instead of TMA) to load K & V, so we want smem_k and smem_v to be aligned. + static constexpr size_t SmemAlignmentQ = Use_TMA_Q && !MmaQK_is_RS ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQ{}); + static constexpr size_t SmemAlignmentK = Use_TMA_KV && !AppendKV ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutK{}); + static constexpr size_t SmemAlignmentVtNoTranspose = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); + static constexpr size_t SmemAlignmentQv = Use_TMA_Q ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQv{}); + static_assert(SmemAlignmentQ >= 128 and SmemAlignmentK >= 128 && SmemAlignmentVtNoTranspose >= 128, "Require at least 128B alignment"); + static constexpr size_t SmemAlignmentP = cutlass::detail::alignment_for_swizzle(SmemLayoutP{}); + static_assert(SmemAlignmentP >= 128, "Require at least 128B alignment"); + + using SmemP_t = std::conditional_t, cute::array_aligned, SmemAlignmentP>>; + using SmemScale_t = std::conditional_t, cute::array_aligned, 128>>; + using SmemQv_t = std::conditional_t, cute::array_aligned, SmemAlignmentQv>>; + // Sometimes even with SmemP_t = cute::array, putting it in the TensorStorage struct causes + // smem size to go from 227KB to 228KB and we get "invalid argument". + + struct TensorStorageWithoutPNoTranspose : cute::aligned_struct { + cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; + cute::array_aligned, SmemAlignmentQ> smem_q; + cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; + }; + + struct TensorStorageWithPNoTranspose : cute::aligned_struct { + cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; + cute::array_aligned, SmemAlignmentQ> smem_q; + cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; + SmemP_t smem_p; + }; + struct TensorStorageWithPScaleNoTranspose : cute::aligned_struct { + cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; + cute::array_aligned, SmemAlignmentQ> smem_q; + cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; + SmemP_t smem_p; + SmemScale_t smem_scale; + }; + + using TensorStorageNoTranspose = std::conditional_t< + MmaPV_is_RS, + TensorStorageWithoutPNoTranspose, + std::conditional_t + >; + + static constexpr size_t SmemAlignmentVt = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); + static constexpr size_t SmemAlignmentV = cutlass::detail::alignment_for_swizzle(SmemLayoutVtMma{}); + static_assert(SmemAlignmentVt >= 128 and SmemAlignmentV >= 128, "Require at least 128B alignment"); + struct TensorStorageTransposeV : cute::aligned_struct { + cute::array_aligned, SmemAlignmentV> smem_v; + cute::array_aligned, SmemAlignmentVt> smem_vt; + cute::array_aligned, SmemAlignmentQ> smem_q; + cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; + SmemScale_t smem_scale; + }; + + using TensorStorage = std::conditional_t; + + // These are tuned for speed. 
They don't affect correctness. + static constexpr bool UseSchedulerBarrier = (IntraWGOverlap + ? (NumMmaWarpGroups >= 2) && (!Is_FP8 ? kHeadDim <= 128 : kHeadDim >= 128) + : NumMmaWarpGroups == 2) + && !LargeHeadDimV; + static constexpr bool RescaleOBeforeGemm = kHeadDim > 128 && (!Is_FP8 || V_colmajor) && IntraWGOverlap; + + // Host side kernel arguments + struct Arguments { + Element const* const ptr_Q; + ShapeQKV const shape_Q; + StrideQK const stride_Q; + Element* const ptr_K; // not Element const* since we might append to KV cache in-place + ShapeQKV const shape_K; + StrideQK const stride_K; + Element* const ptr_V; + int32_t const headdim_v; + StrideV const stride_V; + Element const* const ptr_K_new; + ShapeQKV const shape_K_new; + StrideQK const stride_K_new; + Element const* const ptr_V_new; + StrideV const stride_V_new; + Element const* const ptr_Qv; + StrideQK const stride_Qv; + Element const* const ptr_rotary_cos; + ShapeRotary const shape_rotary; + StrideRotary const stride_rotary_cos; + Element const* const ptr_rotary_sin; + StrideRotary const stride_rotary_sin; + bool const is_rotary_interleaved; + int const* const ptr_pagetable; + ShapePageTable const shape_pagetable; + StridePageTable const stride_pagetable; + float const softmax_scale; + float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; + StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; + int const window_size_left = -1, window_size_right = -1, attention_chunk = 0; + float const softcap_val; + int const num_splits; + int const* const kv_batch_idx = nullptr; + int const* const cu_seqlens_q = nullptr; + int const* const cu_seqlens_k = nullptr; + int const* const cu_seqlens_k_new = nullptr; + int const* const seqused_q = nullptr; + int const* const seqused_k = nullptr; + int const* const leftpad_k = nullptr; + int const* const seqlens_rotary = nullptr; + }; + + // Device side kernel params + struct Params { + Element const* const ptr_Q; + ShapeQKV const shape_Q; + StrideQK const stride_Q; + ShapeQPacked const shape_Q_packed; + StrideQPacked const stride_Q_packed; + Element* const ptr_K; + ShapeQKV const shape_K; + StrideQK const stride_K; + Element* const ptr_V; + int32_t const headdim_v; + StrideV const stride_V; + Element const* const ptr_K_new; + ShapeQKV const shape_K_new; + StrideQK const stride_K_new; + Element const* const ptr_V_new; + StrideV const stride_V_new; + Element const* const ptr_Qv; + StrideV const stride_Qv; + ShapeQPacked const shape_Qv_packed; + StrideQPacked const stride_Qv_packed; + Element const* const ptr_rotary_cos; + ShapeRotary const shape_rotary; + StrideRotary const stride_rotary_cos; + Element const* const ptr_rotary_sin; + StrideRotary const stride_rotary_sin; + bool const is_rotary_interleaved; + int const* const ptr_pagetable; + ShapePageTable const shape_pagetable; + StridePageTable const stride_pagetable; + cutlass::FastDivmod page_size_divmod; + cutlass::FastDivmod blockN_per_page_size_divmod; + cutlass::FastDivmod qhead_per_khead_divmod; + TMA_Q tma_load_Q; + TMA_K tma_load_K; + TMA_V tma_load_V; + TMA_K tma_load_K_new; + TMA_V tma_load_V_new; + TMA_Qv tma_load_Qv; + float const softmax_scale_log2; + float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; + StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; + float const softcap_val; + int const window_size_left, window_size_right; + cutlass::FastDivmod attention_chunk_divmod; + int const num_splits; + int const* const kv_batch_idx = nullptr; + int const* const cu_seqlens_q = 
nullptr; + int const* const cu_seqlens_k = nullptr; + int const* const cu_seqlens_k_new = nullptr; + int const* const seqused_q = nullptr; + int const* const seqused_k = nullptr; + int const* const leftpad_k = nullptr; + int const *const seqlens_rotary = nullptr; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q); + TMA_Q tma_load_Q = make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + mQ, + SmemLayoutQ{}, + TileShape_MNK{}, + ClusterShape{}); // no mcast for Q + Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K); + TMA_K tma_load_K = make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + mK, + take<0, 2>(SmemLayoutK{}), + TileShape_MNK{}, + ClusterShape{}); // mcast along M mode for this N load, if any + Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), + make_shape(args.headdim_v, get<0>(args.shape_K), get<2>(args.shape_K), get<3>(args.shape_K)), + select<1, 0, 2, 3>(args.stride_V)); + TMA_V tma_load_V = make_tma_copy( + GmemTiledCopyKV{}, + mV, + take<0, 2>(SmemLayoutVt{}), + select<1, 2>(TileShape_MNK_PV{}), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + Tensor mKnew = make_tensor(make_gmem_ptr(args.ptr_K_new), args.shape_K_new, args.stride_K_new); + TMA_K tma_load_K_new = make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + cute::conditional_return(mKnew, mK), + take<0, 2>(SmemLayoutK{}), + TileShape_MNK{}, + ClusterShape{}); // mcast along M mode for this N load, if any + Tensor mVnew = make_tensor(make_gmem_ptr(args.ptr_V_new), + make_shape(args.headdim_v, get<0>(args.shape_K_new), get<2>(args.shape_K_new), get<3>(args.shape_K_new)), + select<1, 0, 2, 3>(args.stride_V_new)); + TMA_V tma_load_V_new = make_tma_copy( + GmemTiledCopyKV{}, + cute::conditional_return(mVnew, mV), + take<0, 2>(SmemLayoutVt{}), + select<1, 2>(TileShape_MNK_PV{}), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + auto shape_Qv = make_shape(get<0>(args.shape_Q), args.headdim_v, get<2>(args.shape_Q), get<3>(args.shape_Q)); + Tensor mQv = make_tensor(make_gmem_ptr(args.ptr_Qv), shape_Qv, args.stride_Qv); + TMA_Qv tma_load_Qv = [&] { + if constexpr (HasQv) { + return make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + mQv, + SmemLayoutQv{}, + TileShape_MNK_QV{}, + ClusterShape{}); // no mcast for Qv + } else { + return nullptr; + } + }(); + // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size) + int const qhead_per_khead = !PackGQA ? 
1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K)); + auto const shape_Q_packed = cute::conditional_return( + args.shape_Q, + make_shape(make_shape(qhead_per_khead, get<0>(args.shape_Q)), get<1>(args.shape_Q), get<2>(args.shape_K), get<3>(args.shape_Q)) + ); + auto const stride_Q_packed = cute::conditional_return( + args.stride_Q, + make_stride(make_stride(get<2>(args.stride_Q), get<0>(args.stride_Q)), get<1>(args.stride_Q), get<2>(args.stride_Q) * qhead_per_khead, get<3>(args.stride_Q)) + ); + auto const shape_Qv_packed = cute::conditional_return( + shape_Qv, + make_shape(make_shape(qhead_per_khead, get<0>(shape_Qv)), get<1>(shape_Qv), get<2>(args.shape_K), get<3>(shape_Qv)) + ); + auto const stride_Qv_packed = cute::conditional_return( + args.stride_Qv, + make_stride(make_stride(get<2>(args.stride_Qv), get<0>(args.stride_Qv)), get<1>(args.stride_Qv), get<2>(args.stride_Qv) * qhead_per_khead, get<3>(args.stride_Qv)) + ); + if (get<1>(args.shape_rotary) > 0) { + assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr); + } + assert(args.num_splits >= 1); + int page_size = !args.ptr_pagetable ? 1 : get<0>(args.shape_K); + if (!PagedKVNonTMA && args.ptr_pagetable != nullptr) { + assert(page_size % kBlockN == 0); + assert(!args.leftpad_k); + } + // Avoid dividing by zero + cutlass::FastDivmod attention_chunk_divmod(args.attention_chunk >= 1 ? args.attention_chunk : 1); + attention_chunk_divmod.divisor = args.attention_chunk; + // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. + // Right after this, we multiply by log2(e) before applying exp2. + // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val + // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) + // (assigning it to params.softmax_scale_log2). + return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed, + args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V, + args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new, + args.ptr_Qv, args.stride_Qv, shape_Qv_packed, stride_Qv_packed, + args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos, + args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved, + args.ptr_pagetable, args.shape_pagetable, args.stride_pagetable, + cutlass::FastDivmod(page_size), // page_size_divmod + cutlass::FastDivmod(!args.ptr_pagetable ? 1 : cute::ceil_div(page_size, kBlockN)), // blockN_per_page_size_divmod + cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), + tma_load_Q, tma_load_K, tma_load_V, tma_load_K_new, tma_load_V_new, tma_load_Qv, + !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), + args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale, + args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, + !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, + args.window_size_left, args.window_size_right, attention_chunk_divmod, + !Split ? 
1 : args.num_splits, + args.kv_batch_idx, + args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, + args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + if constexpr (Use_TMA_Q) { + cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor()); + if constexpr (HasQv) { + cute::prefetch_tma_descriptor(params.tma_load_Qv.get_tma_descriptor()); + } + } + if constexpr (Use_TMA_KV) { + cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor()); + } + if constexpr (AppendKV) { + cute::prefetch_tma_descriptor(params.tma_load_K_new.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_V_new.get_tma_descriptor()); + } + } + + template + CUTLASS_DEVICE void + load(Params const& params, + MainloopPipelineK pipeline_k, + MainloopPipelineV pipeline_v, + MainloopPipelineVt pipeline_vt, + PipelineState& smem_pipe_write, + SharedStorage &shared_storage, + SchedulerPrefetch const& scheduler_prefetch, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord, + int &work_idx + ) { + + // some of these are captured in lambda so can't use structured binding + int const m_block = get<0>(block_coord); + int const bidh = get<1>(block_coord); + int const bidb = get<2>(block_coord); + int const split_idx = get<3>(block_coord); + auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); + // It's possible to have n_block_max <= n_block_min. Loading K can cause illegal memory access. + if constexpr (Is_causal || Is_local || Varlen || Split) { + if (n_block_max <= n_block_min) { + scheduler_prefetch(); + return; + } + } + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); + Tensor sK_pi = as_position_independent_swizzle_tensor(sK); + // as_position_independent_swizzle_tensor makes address calculation easier when we do LDSM & STSM to transpose. + // But it requires smem_vt and smem_v to be aligned to e.g 512 bytes. 
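+ // Roughly: the position-independent form folds the swizzle into the layout itself, so an
+ // element's byte offset can be computed without knowing the smem base address; that only
+ // works if the base is already aligned to the full swizzle period, hence the alignment
+ // requirement on smem_vt / smem_v mentioned above.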
+ Tensor sVt = [&] { + if constexpr (!Transpose_V) { + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{}); + } else { + return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVt{})); + } + }(); + // Only used if Transpose_V + Tensor sV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{})); + // Only used if we're using cp.async to load V + Tensor sVcpasync = [&] { + if constexpr (!Transpose_V) { + return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVCpAsync{})); + } else { + return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVCpAsync{})); + } + }(); + Tensor sQv = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_qv.data()), SmemLayoutQv{}); + + int const thread_idx = threadIdx.x % NumProducerThreads; + int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; + int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb]; + + // Prepare the TMA loads + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + bool const is_varlen_q = Varlen && params.cu_seqlens_q; + bool const is_varlen_k = Varlen && params.cu_seqlens_k; + Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor mK_TMA = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, _); + auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K)); + Tensor mVt_TMA = params.tma_load_V.get_tma_tensor(shape_V)(_, _, bidh_kv, _); + + Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + // if (cute::thread0()) { printf("Varlen = %d, params.leftpad_k = %p, leftpad_k = %d\n", Varlen, params.leftpad_k, leftpad_k); } + Tensor gK_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}, _0{}), mK_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}, _)); // (N, K, _, _) + Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k, _0{}), mVt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _, _)); // (K, N, _, _) + + auto block_tma_Q = params.tma_load_Q.get_slice(_0{}); + Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); // (TMA) + Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); // (TMA) + if (Use_TMA_Q && thread_idx == 0) { prefetch(params.tma_load_Q, tQgQ); } + // tma_partition doesn't handle position_independent_swizzle_tensor correctly, so we need to do it manually + auto block_tma_K = params.tma_load_K.get_slice(cluster_local_block_id.x); + Tensor tKgK_TMA = group_modes<0, 3>(block_tma_K.partition_S(gK_TMA)); // (TMA, k, batch) + Tensor tKsK_TMA = group_modes<0, 3>(block_tma_K.partition_D(sK)); // (TMA, PIPE) + auto block_tma_V = params.tma_load_V.get_slice(cluster_local_block_id.x); + Tensor tVgVt_TMA = group_modes<0, 3>(block_tma_V.partition_S(gVt_TMA)); // (TMA, k, batch) + Tensor tVsVt_TMA = group_modes<0, 
3>(block_tma_V.partition_D(sVt)); // (TMA, PIPE) + auto [tQvgQv, tQvsQv] = [&] { + if constexpr (HasQv) { + auto shape_Qv = make_shape(get<0>(params.shape_Q), params.headdim_v, get<2>(params.shape_Q), get<3>(params.shape_Q)); + Tensor mQv = params.tma_load_Qv.get_tma_tensor(shape_Qv)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor gQv = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQv), select<0, 2>(TileShape_MNK_QV{}), make_coord(m_block, _0{})); // (M, Kv) + auto block_tma_Qv = params.tma_load_Qv.get_slice(_0{}); + Tensor tQvgQv = group_modes<0, 3>(block_tma_Qv.partition_S(gQv)); // (TMA) + Tensor tQvsQv = group_modes<0, 3>(block_tma_Qv.partition_D(sQv)); // (TMA) + return cute::make_tuple(tQvgQv, tQvsQv); + } else { + return cute::make_tuple(nullptr, nullptr); + } + }(); + + // This is used to index into the batch dimension of mK and mV + int const bidb_kv_idx = !is_varlen_k && !params.ptr_pagetable ? bidb_kv : 0; + + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumProducerThreads, Element, Transpose_V || !IntraWGOverlap /*KV_Same_Iter*/>; + PagedKVManager_t paged_kv_manager( + params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, + params.ptr_K, params.shape_K, params.stride_K, + params.ptr_V, params.headdim_v, params.stride_V, + params.page_size_divmod, params.blockN_per_page_size_divmod, + bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k, bidb_kv_idx + ); + + // Set up for transposing V, only used if Transpose_V + S2RTiledCopyVt s2r_tiled_copy_vt; + R2STiledCopyV r2s_tiled_copy_v; + auto s2r_thr_copy_vt = s2r_tiled_copy_vt.get_thread_slice(thread_idx); + auto r2s_thr_copy_v = r2s_tiled_copy_v.get_thread_slice(thread_idx); + // flat_divide(sVt, LDSM_divide_shape{}): (64, 8, kHeadDim / 64, kBlockN / 8, kStages) + Tensor tTranssVt_ = s2r_thr_copy_vt.partition_S(flat_divide(sVt, LDSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, kBlockN / 32, kStages) + // flat_divide(sV, STSM_divide_shape{}): (8, 16, kHeadDim / 8, (4, kBlockN / 64), kStages) + Tensor tTranssV_ = r2s_thr_copy_v.partition_D(flat_divide(sV, STSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, (2, kBlockN / 64), kStages) + CUTE_STATIC_ASSERT_V(rank(tTranssVt_) == rank(tTranssV_)); + CUTE_STATIC_ASSERT_V(size<0>(tTranssVt_) == size<0>(tTranssV_)); + CUTE_STATIC_ASSERT_V(size<1>(tTranssVt_) == size<1>(tTranssV_)); + CUTE_STATIC_ASSERT_V(size<2>(tTranssVt_) == size<2>(tTranssV_)); + CUTE_STATIC_ASSERT_V(size<3>(tTranssVt_) == size<3>(tTranssV_)); + CUTE_STATIC_ASSERT_V(size<4>(tTranssVt_) == size<4>(tTranssV_)); + // Faster to have 2 LDSM.T, byte permute, STSM for better ILP + static constexpr int Transpose_ILP = (size<2>(tTranssVt_) * size<3>(tTranssVt_)) % 2 == 0 ? 
2 : 1; + Tensor tTranssVt = logical_divide(group_modes<1, rank(tTranssVt_) - 1>(tTranssVt_), Shape>{}); // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2), kStages) + Tensor tTranssV = logical_divide(group_modes<1, rank(tTranssV_) - 1>(tTranssV_), Shape>{}); // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2), kStages) + auto transpose_V = [&](int stage) { + if constexpr (Transpose_V) { + #pragma unroll + for (int i = 0; i < size<1, 1>(tTranssVt); ++i) { + Tensor tTransrV = make_fragment_like(tTranssV(_, make_coord(_, _0{}), _0{})); + static_assert(size<0>(tTransrV) == 16); + Tensor tTransrV_64 = recast(tTransrV); + cute::copy(s2r_tiled_copy_vt, tTranssVt(_, make_coord(_, i), stage), tTransrV); + #pragma unroll + for (int j = 0; j < size(tTransrV_64); ++j) { + uint32_t upper = tTransrV_64[j].x; + uint32_t lower = tTransrV_64[j].y; + tTransrV_64[j].x = __byte_perm(upper, lower, 0x6420); + tTransrV_64[j].y = __byte_perm(upper, lower, 0x7531); + } + cute::copy(r2s_tiled_copy_v, tTransrV, tTranssV(_, make_coord(_, i), stage)); + } + } + }; + + uint16_t mcast_mask_kv = 0; + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{})); + } + } + + auto load_K = [&] (int const n_block, auto const& smem_pipe_write, auto need_seqlenk_masking_type) { + pipeline_k.producer_acquire(smem_pipe_write); + if constexpr (!PagedKVNonTMA) { + auto [n_block_idx, bidb_kv_idx] = paged_kv_manager.get_indices_for_K_TMA(); + copy(params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST), + tKgK_TMA(_, n_block_idx, bidb_kv_idx), tKsK_TMA(_, smem_pipe_write.index())); + } else { + constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; + paged_kv_manager.template load_K(n_block, sK_pi(_, _, smem_pipe_write.index())); + pipeline_k.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); + } + }; + + auto load_V = [&] (int const n_block, auto const& smem_pipe_write, auto need_seqlenk_masking_type) { + auto pipeline_v_load = cute::conditional_return(pipeline_v, pipeline_vt); + pipeline_v_load.producer_acquire(smem_pipe_write); + if constexpr (!PagedKVNonTMA) { + auto [n_block_idx, bidb_kv_idx] = paged_kv_manager.get_indices_for_V_TMA(); + copy(params.tma_load_V.with(*pipeline_v_load.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST), + tVgVt_TMA(_, n_block_idx, bidb_kv_idx), tVsVt_TMA(_, smem_pipe_write.index())); + } else { + constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; + paged_kv_manager.template load_V(n_block, sVcpasync(_, _, smem_pipe_write.index())); + pipeline_v_load.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); + } + }; + + auto copy_Vt_to_V = [&] (auto const& smem_pipe_write) { + // Instead of maintaining smem_pipe_read as a separate variable, we can just use smem_pipe_write, + // and exploit the invariance that smem_pipe_write.phase() == smem_pipe_read.phase() ^ 1. + // This saves 1 or 2 registers. 
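+ // A minimal standalone sketch of why that invariant holds (toy state below, not the
+ // CUTLASS PipelineState type): the write state starts with the opposite phase bit from
+ // the read state, and both advance once per iteration here, so the phase bits stay
+ // complementary and the read state can be reconstructed from the write state.
+ #if 0
+ struct ToyPipeState {
+     int index; unsigned phase; int count;
+     void advance(int num_stages) {
+         ++index; ++count;
+         if (index == num_stages) { index = 0; phase ^= 1; }  // wrapping flips the phase bit
+     }
+ };
+ // ToyPipeState wr{0, /*phase=*/1, 0}, rd{0, /*phase=*/0, 0};
+ // After any equal number of advance(kStages) calls on both: wr.phase == (rd.phase ^ 1).
+ #endif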
+ PipelineState smem_pipe_read{smem_pipe_write.index(), smem_pipe_write.phase() ^ 1, smem_pipe_write.count()}; + pipeline_vt.consumer_wait(smem_pipe_read); + pipeline_v.producer_acquire(smem_pipe_write); + transpose_V(smem_pipe_write.index()); + // SMEM fence to make sure V is transposed before math + cutlass::arch::fence_view_async_shared(); + pipeline_v.producer_commit(smem_pipe_write); + // Very important: PipelineTmaAsync::consumer_release assumes that the warpgroup is synchronized + // before calling. Without this we get race conditions. + cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, cutlass::arch::ReservedNamedBarriers::TransposeBarrier /*id*/); + pipeline_vt.consumer_release(smem_pipe_read); + }; + + int n_block = n_block_max - 1; + + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + // If this is true, we're guaranteed that only the first warp will execute this function + static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp; + bool should_load_KV = !Use_TMA_KV || ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()); + + if (should_load_KV) { + if constexpr (PagedKVNonTMA) { + paged_kv_manager.template load_page_table(n_block); + } else { + paged_kv_manager.template load_page_table_TMA(n_block); + } + if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); } + // if (thread_idx == 0) { printf("Producer: main load, before load_K, index = %d\n", smem_pipe_write.index());} + load_K(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); + // if (thread_idx == 0) { printf("Producer: main load, after load K, index = %d\n", smem_pipe_write.index());} + } + + if constexpr (Use_TMA_Q) { + // Wait for the MMA warpgroups to signal that smem_q is ready + if (SingleProducerWarp || warp_idx_in_warpgroup == 0) { + cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + } + + if ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()) { + shared_storage.pipelines.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); + copy(params.tma_load_Q.with(reinterpret_cast(shared_storage.pipelines.barrier_Q), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST), + tQgQ, tQsQ); + if constexpr (HasQv) { + shared_storage.pipelines.barrier_Qv.arrive_and_expect_tx(TmaTransactionBytesQv); + copy(params.tma_load_Qv.with(reinterpret_cast(shared_storage.pipelines.barrier_Qv), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST), + tQvgQv, tQvsQv); + } + } + } else { // Load Q with cp.async + cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? 
bidb : 0); + Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ); + using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element>; + PackGQAt::load_Q(mQ, sQ_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block); + auto &barrier_Q = shared_storage.pipelines.barrier_Q; + cutlass::arch::cpasync_barrier_arrive(reinterpret_cast(&barrier_Q)); + barrier_Q.arrive(); + if constexpr (HasQv) { + Tensor mQv = make_tensor(make_gmem_ptr(params.ptr_Qv + seqlen_info.offset_q * get<0>(params.stride_Qv)), params.shape_Qv_packed, params.stride_Qv_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor sQv_pi = cute::as_position_independent_swizzle_tensor(sQv); + using PackGQAt = flash::PackGQAManager(TileShape_MNK_QV{}), get<2>(TileShape_MNK_QV{}), NumProducerThreads, Element>; + PackGQAt::load_Q(mQv, sQv_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block); + auto &barrier_Qv = shared_storage.pipelines.barrier_Qv; + cutlass::arch::cpasync_barrier_arrive(reinterpret_cast(&barrier_Qv)); + barrier_Qv.arrive(); + } + } + + // Wait for the MMA WGs to signal that smem_v are ready and V can be copied from gmem + // Need ClusterBarrier, not just NamedBarrier. Otherwise we might have CTA 0 finishing the + // TMA store on O first, call TMA multicast load on V, before CTA 1 can finishing TMA store on O. + // if (thread_idx == 0) { printf("Producer: main load, before barrier_O, work_idx = %d\n", work_idx);} + shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2); + // if (thread_idx == 0) { printf("Producer: main load, after barrier_O\n");} + + if constexpr (!Transpose_V && !IntraWGOverlap) { + if (should_load_KV) { load_V(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); } + } + int n_block_prev = n_block; + --n_block; + #pragma unroll (!Transpose_V && Use_TMA_KV ? 2 : 1) + for (; n_block >= n_block_min; --n_block) { + PipelineState smem_pipe_write_v = smem_pipe_write; // copy the state, write_v is always 1 step behind + ++smem_pipe_write; + if (should_load_KV) { + if constexpr (PagedKVNonTMA) { + paged_kv_manager.template load_page_table(n_block); + } else { + paged_kv_manager.load_page_table_TMA(n_block); + } + if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); } + load_K(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); + if constexpr (!Transpose_V) { + if constexpr (IntraWGOverlap) { + load_V(n_block_prev, smem_pipe_write_v, cute::true_type{} /*Seqlenk_mask*/); + } else { + load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); + } + } + } + n_block_prev = n_block; + if constexpr (Transpose_V) { copy_Vt_to_V(smem_pipe_write_v); } + } + scheduler_prefetch(); + if constexpr (!Transpose_V && IntraWGOverlap) { + if (should_load_KV) { load_V(n_block_prev, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); } + } + if constexpr (Transpose_V) { copy_Vt_to_V(smem_pipe_write); } + ++smem_pipe_write; + // At the end, all threads have the correct smem_pipe_write. + ++work_idx; + } + + template + CUTLASS_DEVICE void + load_tail(MainloopPipelineK pipeline_k, MainloopPipelineV pipeline_v, MainloopPipelineVt pipeline_vt, + PipelineState& smem_pipe_write, SharedStorage &shared_storage, int const work_idx) { + // If we don't wait for barrier_O here, when using Cluster, CTA0 might exit early and CTA1 will + // try to arrive on barrier_O of CTA0, causing "unspecified launch failure". 
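+ // barrier_O is completed once per work tile by the epilogue, so the parity to wait on
+ // alternates with work_idx; (work_idx + 1) % 2 picks the completion that means the
+ // previous tile's O has been read out of smem and the buffers can be overwritten.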
+ shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2); + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + // Issue the epilogue waits + // TODO: check if this should be called by 1 thread or more + if (warp_idx_in_warpgroup == 0 && cute::elect_one_sync()) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was still inverted from make_producer_start_state + */ + pipeline_k.producer_tail(smem_pipe_write); + pipeline_v.producer_tail(smem_pipe_write); + if constexpr (Transpose_V) { pipeline_vt.producer_tail(smem_pipe_write); } + } + } + + CUTLASS_DEVICE void + warp_scheduler_barrier_sync() { + if constexpr (UseSchedulerBarrier) { + cutlass::arch::NamedBarrier::sync(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/); + } + } + + CUTLASS_DEVICE void + warp_scheduler_barrier_arrive() { + if constexpr (UseSchedulerBarrier) { + static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); + int const cur_WG = flash::canonical_warp_group_idx_nosync() - 1; + int const next_WG = NumMmaWarpGroups == 2 + ? 1 - cur_WG + : (cur_WG < NumMmaWarpGroups - 1 ? cur_WG + 1 : 0); + cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) + next_WG /*id*/); + } + } + + CUTLASS_DEVICE void + mma_init() { + int warp_group_idx = flash::canonical_warp_group_idx_nosync(); + // Tell producers that smem_q is ready + if (!LargeHeadDimV || warp_group_idx == 1) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + } + if (LargeHeadDimV && warp_group_idx > 1) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + } + if constexpr (UseSchedulerBarrier) { + // We have NamedBarrier for up to 3 WGs + static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); + // WG1 needs the very first signal to start + if (warp_group_idx == 1) { + cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) /*id*/); + } + } + } + + template + CUTLASS_DEVICE bool + mma(Params const& params, + MainloopPipelineK pipeline_k, + MainloopPipelineV pipeline_v, + PipelineState& smem_pipe_read, + FrgTensorO& tOrO, + Softmax& softmax, + int const thread_idx, + int &work_idx, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord, + SharedStorage& shared_storage + ) { + static_assert(is_rmem::value, "O tensor must be rmem resident."); + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + + // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda + int const m_block = get<0>(block_coord); + int const bidh = get<1>(block_coord); + int const bidb = get<2>(block_coord); + int const split_idx = get<3>(block_coord); + int const bidh_kv = !PackGQA ? 
params.qhead_per_khead_divmod.divide(bidh) : bidh; + auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); + // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier + if constexpr (Is_causal || Is_local || Varlen || Split) { + if (n_block_max <= n_block_min) { return false; } + } + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{}); + Tensor sP = [&] { + if constexpr (MmaPV_is_RS) { + // We might not have smem_p if !MmaPV_is_RS, just use smem_q as a placeholder since we don't use it + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutP{}); + } else { + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{}); + } + }(); + Tensor sScale = [&] { + if constexpr (LargeHeadDimV) { + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_scale.data()), SmemLayoutScale{}); + } else { // won't be used, just a placeholder + return make_tensor(make_smem_ptr(static_cast(nullptr)), SmemLayoutScale{}); + } + }(); + Tensor sQv = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_qv.data()), SmemLayoutQv{}); + Tensor sVMmaQV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVMmaQV{}); + + if constexpr (!MmaQK_is_RS) { + static_assert(stride<0>(typename TiledMmaQK::ALayout{}) == 0 and + stride<0>(typename TiledMmaQK::BLayout{}) == 0 and + size<0>(typename TiledMmaQK::ALayout{}) == cutlass::NumThreadsPerWarpGroup and + size<0>(typename TiledMmaQK::BLayout{}) == cutlass::NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + } + static constexpr int MmaWarpGroups = size(TiledMmaPV{}) / cutlass::NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(make_shape(Int{}), + make_stride(Int{})); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); + TiledMmaQK tiled_mma_qk; + TiledMmaPV tiled_mma_pv; + TiledMmaQV tiled_mma_qv; + auto wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx)); + auto wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)); + auto wg_mma_qv = tiled_mma_qv.get_slice(warp_group_thread_layout(warp_group_idx)); + + auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtomP{}, tiled_mma_qk); + auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx); + + // Allocate "fragments/descriptors" + Tensor tSrQ = wg_mma_qk.partition_fragment_A(sQ); + Tensor tSrK = wg_mma_qk.partition_fragment_B(sK); + Tensor tOrV = wg_mma_pv.partition_fragment_B(sV); + Tensor tOsP = wg_mma_pv.partition_fragment_A(sP); + Tensor tSrQv = wg_mma_qv.partition_fragment_A(sQv); + Tensor tSrV = wg_mma_qv.partition_fragment_B(sVMmaQV); + Tensor tPsP = smem_thr_copy_P.partition_D(cute::as_position_independent_swizzle_tensor(sP)); + + // For storing scales to smem, only used when LargeHeadDimV + auto thread_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx); + Tensor taccOcO = 
thread_mma_pv.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); + Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); + auto store_scales = [&](auto& scales, int stage) { + static_assert(CUTE_STATIC_V(size(scales)) == CUTE_STATIC_V(size(taccOcO_row))); + #pragma unroll + for (int mi = 0; mi < size(taccOcO_row); ++mi) { + if (get<1>(taccOcO_row(_0{})) == 0) { + sScale(get<0>(taccOcO_row(mi)), stage) = scales(mi); + } + } + }; + + auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + }; + + int const seqlen_q = seqlen_info.seqlen_q; + int const seqlen_k = seqlen_info.seqlen_k; + int n_block = n_block_max - 1; + + flash::Mask mask( + thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, + params.attention_chunk_divmod, params.qhead_per_khead_divmod + ); + + float softcap_val = params.softcap_val; + if constexpr (Has_softcap && Is_FP8) { + float const q_descale = params.ptr_q_descale == nullptr ? 1.0f : params.ptr_q_descale[bidb * get<0>(params.stride_q_descale) + bidh_kv * get<1>(params.stride_q_descale)]; + float const k_descale = params.ptr_k_descale == nullptr ? 1.0f : params.ptr_k_descale[bidb * get<0>(params.stride_k_descale) + bidh_kv * get<1>(params.stride_k_descale)]; + softcap_val *= q_descale * k_descale; + } + // Softcapping needs to happen before masking since if we apply after masking, softcapping + // can turn -inf to e.g. -50.0, which can affect the attention softmax. + auto scoremod_premask_fn = [&](auto& tSrS) { + if constexpr (Has_softcap) { flash::apply_softcap(tSrS, softcap_val); } + }; + + auto write_P_to_smem = [&](auto& tOrP) { + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + } + cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); + }; + + auto arrive_on_P_write_barrier = [&] { + cutlass::arch::fence_view_async_shared(); + __syncwarp(); // Only need syncwarp since each warp is using its own P values for MmaPV + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } + }; + + auto &barrier_Q = shared_storage.pipelines.barrier_Q; + if constexpr (!AppendKV) { + barrier_Q.wait(work_idx % 2); + } else { + if (get<1>(params.shape_rotary) > 0) { // Apply rotary to Q + using Rotary_t = Rotary; + Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, + params.ptr_rotary_sin, params.stride_rotary_sin, + params.is_rotary_interleaved, thread_idx, seqlen_q, + seqlen_info.seqlen_rotary); + Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ); + int const qhead_per_khead = !PackGQA ? 
1 : params.qhead_per_khead_divmod.divisor; + if (params.is_rotary_interleaved) { + auto [tRrCos, tRrSin] = cute::conditional_return( + rotary.template load_cos_sin(m_block), + rotary.template load_cos_sin_packgqa(m_block, params.qhead_per_khead_divmod) + ); + barrier_Q.wait(work_idx % 2); + rotary.apply_Q_interleaved(sQ_pi, tRrCos, tRrSin, m_block, qhead_per_khead); + } else { + auto [tRrCosCont, tRrSinCont] = cute::conditional_return( + rotary.template load_cos_sin(m_block), + rotary.template load_cos_sin_packgqa(m_block, params.qhead_per_khead_divmod) + ); + barrier_Q.wait(work_idx % 2); + rotary.apply_Q_contiguous(sQ_pi, tRrCosCont, tRrSinCont, m_block, qhead_per_khead); + } + // SMEM fence to make sure the rotated Q is visible to GMMA + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK, static_cast(FwdNamedBarriers::QueryRotated) /*id*/); + } else { + barrier_Q.wait(work_idx % 2); + } + } + + if constexpr (MmaQK_is_RS) { + using SmemCopyAtomQ = Copy_Atom; + auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtomQ{}, tiled_mma_qk); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(thread_idx); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + Tensor tSsQ_copy_view = smem_thr_copy_Q.partition_S(cute::as_position_independent_swizzle_tensor(sQ)); + cute::copy(smem_tiled_copy_Q, tSsQ_copy_view, tSrQ_copy_view); + } + + if constexpr (IntraWGOverlap) { + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); + consumer_wait(pipeline_k, smem_pipe_read); + flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + warpgroup_wait<0>(); + pipeline_k.consumer_release(smem_pipe_read); + if constexpr (HasQv) { + shared_storage.pipelines.barrier_Qv.wait(work_idx % 2); + consumer_wait(pipeline_v, smem_pipe_read); + flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); + } + scoremod_premask_fn(tSrS); + mask.template apply(tSrS, m_block, n_block); + + Tensor scores_scale = softmax.template max_get_scale(tSrS); + // Don't need to store scales to send to WG1 (in the case of LargeHeadDimV) since it's 1.f + + softmax.template online_softmax(tSrS); + if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } + Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); + Tensor tOrP = make_tensor_like(tOrP_acc); + convert_type_out(tOrP_acc, tOrP); + if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } + if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } + if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); } + --n_block; + + // Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter + clear(tOrO); + // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; + + // Each step does gemm0 for iter n_block, gemm1 for iter n_block + 1, and softmax for iter n_block. 
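+ // A minimal sketch of that schedule (block indices only, tensor math and pipeline
+ // synchronization elided), with blocks visited from n_block_max - 1 down to n_block_min:
+ //   S = Q * K[n_max - 1]; P_prev = softmax(S);      // prologue, done above
+ //   for (n = n_max - 2; n >= n_min; --n) {          // fwd_step below
+ //       S  = Q * K[n];                              // gemm0 for iter n
+ //       O += P_prev * V[n + 1];                     // gemm1 for iter n + 1
+ //       P_prev = softmax(S);                        // softmax for iter n
+ //   }
+ //   O += P_prev * V[n_min];                         // final gemm1, after the loops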
+ auto fwd_step = [&](int const n_block, auto mask_fn, auto check_inf_type) { + static constexpr bool Check_inf = decltype(check_inf_type)::value; + PipelineState smem_pipe_read_v(smem_pipe_read.index(), smem_pipe_read.phase(), smem_pipe_read.count()); + ++smem_pipe_read; + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); + if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_k, smem_pipe_read); } + warp_scheduler_barrier_sync(); + flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } + if constexpr(!HasQv) { + if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); } + } + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + warp_scheduler_barrier_arrive(); + warpgroup_wait<1>(); + pipeline_k.consumer_release(smem_pipe_read); // release K + if constexpr (HasQv) { + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + consumer_wait(pipeline_v, smem_pipe_read); + flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); + } + scoremod_premask_fn(tSrS); + mask_fn(tSrS, n_block); + cute::copy(softmax.template max_get_scale(tSrS), scores_scale); + if constexpr (LargeHeadDimV) { store_scales(scores_scale, smem_pipe_read_v.index()); } + softmax.template online_softmax(tSrS); + if constexpr (!HasQv) { + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + } + if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } + convert_type_out(make_tensor(tSrS.data(), tOrP.layout()), tOrP); + if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } + if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } + if constexpr (!RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } + if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); } + }; + + if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking + auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; + int const n_block_min_causal_local_mask = BlockMN_t::get_n_block_min_causal_local_mask( + seqlen_info, m_block, n_block_min, params.window_size_right, + params.attention_chunk_divmod, params.qhead_per_khead_divmod); + #pragma unroll 1 + for (; n_block >= n_block_min_causal_local_mask; --n_block) { + fwd_step(n_block, mask_fn, cute::true_type{} /*check_inf*/); + } + } + + int const n_block_min_before_local_mask = BlockMN_t::get_n_block_min_before_local_mask( + seqlen_info, m_block, n_block_min, params.window_size_left, + params.attention_chunk_divmod, params.qhead_per_khead_divmod); + auto no_mask_fn = [](auto& tSrS, int n_block) { }; + #pragma unroll 1 + for (; n_block >= n_block_min_before_local_mask; --n_block) { + fwd_step(n_block, no_mask_fn, cute::false_type{} /*check_inf*/); + } + // Separate masking iterations on the left for local attention + if constexpr (Is_local) { + auto local_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; + #pragma unroll 1 + for (; n_block >= n_block_min; --n_block) { + fwd_step(n_block, local_mask_fn, cute::bool_constant{} /*check_inf*/); + } + } + // Tell producers that smem_q is ready + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? 
cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } + if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); + float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; + cute::copy(softmax.finalize(v_descale), scores_scale); + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + store_scales(scores_scale, smem_pipe_read.index()); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read); // release V, otherwise producers will hang + softmax.rescale_o(tOrO, scores_scale); + if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } + ++smem_pipe_read; + + } else { // No intra-WG overlap + + warp_scheduler_barrier_sync(); + + auto fwd_step = [&](int const n_block, auto mask_fn, auto is_first_iter_type, auto check_inf_type) { + static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value; + static constexpr bool Check_inf = decltype(check_inf_type)::value; + auto smem_pipe_read_prev = smem_pipe_read; + if constexpr (!Is_first_iter) { ++smem_pipe_read; } + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); + consumer_wait(pipeline_k, smem_pipe_read); + flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + if constexpr (!HasQv) { + warp_scheduler_barrier_arrive(); + warpgroup_wait<0>(); + pipeline_k.consumer_release(smem_pipe_read); // release K + } else { + if constexpr (Is_first_iter) { + shared_storage.pipelines.barrier_Qv.wait(work_idx % 2); + } + consumer_wait(pipeline_v, smem_pipe_read); + flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); + warp_scheduler_barrier_arrive(); + warpgroup_wait<1>(); + pipeline_k.consumer_release(smem_pipe_read); // release K + warpgroup_wait<0>(); + } + scoremod_premask_fn(tSrS); + mask_fn(tSrS, n_block); + Tensor scores_scale = softmax.template max_get_scale(tSrS); + if constexpr (LargeHeadDimV && !Is_first_iter) { store_scales(scores_scale, smem_pipe_read_prev.index()); } + softmax.template online_softmax(tSrS); + if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } + Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); + Tensor tOrP = make_tensor_like(tOrP_acc); + convert_type_out(tOrP_acc, tOrP); + if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } + if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } + if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); } + if constexpr (!MmaPV_is_RS && !MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); } + if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } + warp_scheduler_barrier_sync(); + if constexpr (!MmaPV_use_RS_WG1) { + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); + } else { + TiledMmaPV_RS tiled_mma_pv_rs; + flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + } + if constexpr (!MmaPV_is_RS && MmaPV_use_RS_WG1) { 
arrive_on_P_write_barrier(); } + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read); // release V + }; + + auto first_iter_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; + fwd_step(n_block, first_iter_mask_fn, cute::true_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/); + --n_block; + if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking + auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; + int const n_block_min_causal_local_mask = BlockMN_t::get_n_block_min_causal_local_mask( + seqlen_info, m_block, n_block_min, params.window_size_right, + params.attention_chunk_divmod, params.qhead_per_khead_divmod); + #pragma unroll 1 + for (; n_block >= n_block_min_causal_local_mask; --n_block) { + fwd_step(n_block, mask_fn, cute::false_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/); + } + } + int const n_block_min_before_local_mask = BlockMN_t::get_n_block_min_before_local_mask( + seqlen_info, m_block, n_block_min, params.window_size_left, + params.attention_chunk_divmod, params.qhead_per_khead_divmod); + auto no_mask_fn = [](auto& tSrS, int n_block) { }; + #pragma unroll 1 + for (; n_block >= n_block_min_before_local_mask; --n_block) { + fwd_step(n_block, no_mask_fn, cute::false_type{} /*is_first_iter*/, cute::false_type{} /*check_inf*/); + } + // Separate masking iterations on the left for local attention + if constexpr (Is_local) { + auto local_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; + #pragma unroll 1 + for (; n_block >= n_block_min; --n_block) { + fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant{} /*check_inf*/); + } + } + warp_scheduler_barrier_arrive(); + // Tell producers that smem_q is ready + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; + Tensor scores_scale = softmax.finalize(v_descale); + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + store_scales(scores_scale, smem_pipe_read.index()); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } + softmax.rescale_o(tOrO, scores_scale); + if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } + ++smem_pipe_read; + } + ++work_idx; + return true; + } + + template + CUTLASS_DEVICE bool + mma_pv(Params const& params, + MainloopPipelineV pipeline_v, + PipelineState& smem_pipe_read, + FrgTensorO& tOrO, + Softmax& softmax, + int const thread_idx, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord, + SharedStorage& shared_storage + ) { + static_assert(is_rmem::value, "O tensor must be rmem resident."); + // can't use auto [m_block, ...] 
= block_coord since structured binding cannot be captured in lambda + int const m_block = get<0>(block_coord); + int const bidb = get<2>(block_coord); + int const split_idx = get<3>(block_coord); + auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); + // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier + if constexpr (Is_causal || Is_local || Varlen || Split) { + if (n_block_max <= n_block_min) { return false; } + } + + Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{}); + Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{}); + Tensor sScale = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_scale.data()), SmemLayoutScale{}); + static constexpr int MmaWarpGroups = size(TiledMmaPV{}) / cutlass::NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(make_shape(Int{}), + make_stride(Int{})); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); + TiledMmaPV tiled_mma_pv; + auto wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)); + + // Allocate "fragments/descriptors" + Tensor tOrV = wg_mma_pv.partition_fragment_B(sV); + Tensor tOsP = wg_mma_pv.partition_fragment_A(sP); + + // For load scales to smem, pretend thread_idx is thread_idx % 128 + auto thread_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx % cutlass::NumThreadsPerWarpGroup); + Tensor taccOcO = thread_mma_pv.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); + Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); + auto load_scales = [&](auto& scales, int stage) { + static_assert(CUTE_STATIC_V(size(scales)) == CUTE_STATIC_V(size(taccOcO_row))); + #pragma unroll + for (int mi = 0; mi < size(taccOcO_row); ++mi) { + scales(mi) = sScale(get<0>(taccOcO_row(mi)), stage); + } + }; + + // clear(tOrO); + // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; + + typename Softmax::TensorT scores_scale; + + int n_block = n_block_max - 1; + // If HasQv, then by the time P is ready, V must have been ready as well + if constexpr (!HasQv) { pipeline_v.consumer_wait(smem_pipe_read); } + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + pipeline_v.consumer_release(smem_pipe_read); // release V + --n_block; + + #pragma unroll 1 + for (; n_block >= n_block_min; --n_block) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + load_scales(scores_scale, smem_pipe_read.index()); + softmax.rescale_o(tOrO, scores_scale); + ++smem_pipe_read; + if constexpr (!HasQv) { + auto barrier_token = pipeline_v.consumer_try_wait(smem_pipe_read); + pipeline_v.consumer_wait(smem_pipe_read, barrier_token); + } + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + pipeline_v.consumer_release(smem_pipe_read); // release V + 
}; + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + load_scales(scores_scale, smem_pipe_read.index()); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + softmax.rescale_o(tOrO, scores_scale); + if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } + ++smem_pipe_read; + return true; + } + + template + CUTLASS_DEVICE bool + load_kv_new(Params const& params, + MainloopPipelineKVNew pipeline_k_new, + MainloopPipelineKVNew pipeline_v_new, + PipelineState& smem_pipe_write, + SharedStorage &shared_storage, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord, + int const work_idx + ) { + + auto [m_block, bidh, bidb, split_idx] = block_coord; + auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); + + if (n_block_new_max <= n_block_new_min) { return false; } + + Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); + Tensor sVt = [&] { + if constexpr (!Transpose_V) { + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{}); + } else { + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVt{}); + } + }(); + + // int const thread_idx = threadIdx.x % NumProducerThreads; + int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; + + // Prepare the TMA loads + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new; + Tensor mKnew_TMA = params.tma_load_K_new.get_tma_tensor(params.shape_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); + auto shape_Vnew = make_shape(params.headdim_v, get<0>(params.shape_K_new), get<2>(params.shape_K_new), get<3>(params.shape_K_new)); + Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(shape_Vnew)(_, _, bidh_kv, !is_varlen_k_new ? 
bidb : 0); + + Tensor gKnew_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _)); // (K, N, _) + + auto block_tma_K_new = params.tma_load_K_new.get_slice(cluster_local_block_id.x); + Tensor tKgKnew_TMA = group_modes<0, 3>(block_tma_K_new.partition_S(gKnew_TMA)); // (TMA, k) + Tensor tKsK_TMA = group_modes<0, 3>(block_tma_K_new.partition_D(sK)); // (TMA, PIPE) + auto block_tma_V_new = params.tma_load_V_new.get_slice(cluster_local_block_id.x); + Tensor tVgVnewt_TMA = group_modes<0, 3>(block_tma_V_new.partition_S(gVnewt_TMA)); // (TMA, k) + Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V_new.partition_D(sVt)); // (TMA, PIPE) + + uint16_t mcast_mask_kv = 0; + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{})); + } + } + + auto load_K_new = [&] (int const n_block, auto const& smem_pipe_write) { + pipeline_k_new.producer_acquire(smem_pipe_write); + copy(params.tma_load_K_new.with(*pipeline_k_new.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_FIRST), + tKgKnew_TMA(_, n_block), tKsK_TMA(_, smem_pipe_write.index())); + }; + + auto load_V_new = [&] (int const n_block, auto const& smem_pipe_write) { + pipeline_v_new.producer_acquire(smem_pipe_write); + copy(params.tma_load_V_new.with(*pipeline_v_new.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_FIRST), + tVgVnewt_TMA(_, n_block), tVsVt_TMA(_, smem_pipe_write.index())); + }; + + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + // If this is true, we're guaranteed that only the first warp will execute this function + static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp; + bool should_load_KV = (SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync(); + + int n_block = n_block_new_max - 1; + // Need to wait for barrier_O even before load_K_new since the pipelines for AppendKV + // and the main attention are not the same. We want to make sure the consumers + // have finished reading all smem_k and smem_v for the previous iteration. 
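+ // (The new-KV loads reuse smem_k / smem_v but go through pipeline_k_new / pipeline_v_new,
+ // whose producer_acquire only orders against consumers of those pipelines, so the epilogue
+ // barrier is what guarantees the previous tile is done with the shared buffers.)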
+ shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2); + if (should_load_KV) { load_K_new(n_block, smem_pipe_write); } + // if (thread_idx == 0) { printf("Producer: Done loading K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } + if (should_load_KV) { load_V_new(n_block, smem_pipe_write); } + // if (thread_idx == 0) { printf("Producer: Done loading V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } + ++smem_pipe_write; + --n_block; + // if (thread_idx == 0) { printf("Producer: before for loop\n"); } + #pragma unroll 1 + for (; n_block >= n_block_new_min; --n_block) { + if (should_load_KV) { + load_K_new(n_block, smem_pipe_write); + // if (thread_idx == 0) { printf("Producer: Done loading K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } + load_V_new(n_block, smem_pipe_write); + // if (thread_idx == 0) { printf("Producer: Done loading V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } + } + ++smem_pipe_write; + } + // if (thread_idx == 0) { printf("Producer: after for loop\n"); } + // At the end, all threads have the correct smem_pipe_write. + return true; + } + + template + CUTLASS_DEVICE bool + store_kv_new(Params const& params, + MainloopPipelineKVNew pipeline_k_new, + MainloopPipelineKVNew pipeline_v_new, + PipelineState& smem_pipe_read, + int const thread_idx, + SharedStorage &shared_storage, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord + ) { + auto [m_block, bidh, bidb, split_idx] = block_coord; + auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); + if (n_block_new_max <= n_block_new_min) { return false; } + + // as_position_independent_swizzle_tensor makes address calculation easier + Tensor sK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{})); + // We want to use SmemLayoutVCpAsync to have shape (kBlockN, kHeadDim) instead of (kHeadDim, kBlockN) + Tensor sV = [&] { + if constexpr (!Transpose_V) { + return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVCpAsync{})); + } else { + return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVCpAsync{})); + } + }(); + + int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; + int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb]; + + bool const is_varlen_k = Varlen && params.cu_seqlens_k; + Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K)); + Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), shape_V, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? 
bidb_kv : 0); + + int const offset_k = seqlen_info.offset_k + seqlen_info.seqlen_k_og; + Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<2, 1>(TileShape_MNK_PV{}), make_coord(_, _0{})); // (N, K_v, _) + + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + int const seqlen_k_new = seqlen_info.seqlen_k_new; + using Rotary_t = Rotary; + Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, + params.ptr_rotary_sin, params.stride_rotary_sin, + params.is_rotary_interleaved, thread_idx, seqlen_k_new, + seqlen_info.seqlen_rotary); + + // This is used to index into the batch dimension of mK and mV + int const bidb_kv_idx = !is_varlen_k && !params.ptr_pagetable ? bidb_kv : 0; + + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; + PagedKVManager_t paged_kv_manager( + params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, + params.ptr_K, params.shape_K, params.stride_K, + params.ptr_V, params.headdim_v, params.stride_V, + params.page_size_divmod, params.blockN_per_page_size_divmod, + bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k, bidb_kv_idx + // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position + ); + + if constexpr (UseSchedulerBarrier) { + // WG1 already got the very first signal from mma_init(), but we'll be using the same NamedBarrier. + // So we'll need to "cancel it out" here and then re-signal it at the end. 
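+ // (NamedBarriers are counting barriers keyed by id; store_K / store_V below sync on the
+ // same WarpSchedulerWG* ids, so WG1's pending arrive from mma_init() has to be consumed
+ // here and re-issued after the stores to keep the scheduler's arrival count balanced.)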
+ if (flash::canonical_warp_group_idx_nosync() == 1) { + cutlass::arch::NamedBarrier::sync(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) /*id*/); + } + } + + static_assert(std::is_same_v); + static_assert(!PagedKVNonTMA || std::is_same_v); + GmemTiledCopyAppendKV gmem_tiled_copy_kv; + auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_thread_slice(thread_idx); + Tensor tKsK = gmem_thr_copy_kv.partition_S(sK); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tKgK = gmem_thr_copy_kv.partition_D(gK); + Tensor tVsV = gmem_thr_copy_kv.partition_S(sV); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tVgV = gmem_thr_copy_kv.partition_D(gV); + Tensor cK = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tKcK = gmem_thr_copy_kv.partition_D(cK); + Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); + #pragma unroll + for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKcK(_0{}, _0{}, k)) < get<1>(params.shape_K); } + Tensor cV = cute::make_identity_tensor(select<2, 1>(TileShape_MNK_PV{})); // (BLK_N,BLK_K_V) -> (blk_n,blk_k_v) + Tensor tVcV = cute::conditional_return(tKcK, gmem_thr_copy_kv.partition_D(cV)); + Tensor tVpV_ = make_tensor(make_shape(size<2>(tVsV))); + #pragma unroll + for (int k = 0; k < size(tVpV_); ++k) { tVpV_(k) = get<1>(tVcV(_0{}, _0{}, k)) < params.headdim_v; } + Tensor tVpV = cute::conditional_return(tKpK, tVpV_); + + auto store_K = [&] (int const n_block, auto const& smem_pipe_read) { + int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN); + if (get<1>(params.shape_rotary) <= 0) { + pipeline_k_new.consumer_wait(smem_pipe_read); + Tensor tKsK_cur = tKsK(_, _, _, smem_pipe_read.index()); + if constexpr (!PagedKVNonTMA) { + Tensor tKgK_cur = tKgK(_, _, _, n_block); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_kv, tKsK_cur, tKgK_cur, tKcK, tKpK, std::min(seqlen_k_new - n_block * kBlockN, kBlockN) + ); + } else { + paged_kv_manager.store_K(n_block, tKsK_cur); + } + } else { + Tensor gK_cur = gK(_, _, n_block); + auto tPrKPtr = cute::conditional_return(paged_kv_manager.compute_K_ptr(), nullptr); + if (params.is_rotary_interleaved) { + auto [tRrCos, tRrSin] = rotary.template load_cos_sin(n_block); + pipeline_k_new.consumer_wait(smem_pipe_read); + rotary.template apply_K_interleaved(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCos, tRrSin, tPrKPtr, n_block); + } else { + auto [tRrCosCont, tRrSinCont] = rotary.template load_cos_sin(n_block); + pipeline_k_new.consumer_wait(smem_pipe_read); + rotary.template apply_K_contiguous(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCosCont, tRrSinCont, tPrKPtr, n_block, get<1>(params.shape_K)); + } + } + // Without this fence I'm getting race condition when seqlen_k is large + cutlass::arch::fence_view_async_shared(); + // Very important: PipelineTmaAsync::consumer_release assumes that the warpgroup is synchronized + // before calling. 
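+ // Roughly, the pattern for retiring a TMA-filled stage that was read with non-TMA
+ // instructions is: finish the smem reads, fence so those reads are ordered before the
+ // mbarrier arrive inside consumer_release, sync the whole warpgroup, then release;
+ // skipping any of these lets the producer overwrite a stage another warp is still reading.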
+ cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/); + pipeline_k_new.consumer_release(smem_pipe_read); + // if (thread_idx == 0) { print_tensor(tKpK); printf("\n"); printf("seqlen_limit = %d\n", seqlen_k_new - n_block * kBlockN);} + }; + + auto store_V = [&] (int const n_block, auto const& smem_pipe_read) { + pipeline_v_new.consumer_wait(smem_pipe_read); + int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN); + Tensor tVsV_cur = tVsV(_, _, _, smem_pipe_read.index()); + if constexpr (!PagedKVNonTMA) { + Tensor tVgV_cur = tVgV(_, _, _, n_block); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_kv, tVsV_cur, tVgV_cur, tVcV, tVpV, n_limit); + } else { + paged_kv_manager.store_V(n_block, tVsV_cur); + } + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/); + pipeline_v_new.consumer_release(smem_pipe_read); + }; + + #pragma unroll 1 + for (int n_block = n_block_new_max - 1; n_block >= n_block_new_min; --n_block) { + if constexpr (PagedKVNonTMA) { paged_kv_manager.template load_page_table(n_block); } + store_K(n_block, smem_pipe_read); + // if (thread_idx == 0) { printf("Done storing K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } + store_V(n_block, smem_pipe_read); + // if (thread_idx == 0) { printf("Done storing V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } + ++smem_pipe_read; + } + // if (thread_idx == 0) { printf("After for loop\n"); } + + // Re-signaling the NamedBarrier that we "canceled out" + if constexpr (UseSchedulerBarrier) { + if (flash::canonical_warp_group_idx_nosync() == 1) { + cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) /*id*/); + } + } + + return true; + + } + +}; + +} // namespace flash diff --git a/flash-attn/mask.h b/flash-attn/mask.h new file mode 100644 index 0000000000000000000000000000000000000000..d43e5ee156af61b1d13520d37e3e7f82a6407ee0 --- /dev/null +++ b/flash-attn/mask.h @@ -0,0 +1,166 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include + +#include "cutlass/fast_math.h" // For cutlass::FastDivmod + +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +struct Mask { + + static_assert(!(PackGQA && SwapAB), "Cannot be both PackGQA and SwapAB"); + + int const thread_idx; + int const seqlen_q, seqlen_k; + int const window_size_left, window_size_right, sink_token_length; + cutlass::FastDivmod const attention_chunk_divmod; + cutlass::FastDivmod const qhead_per_khead_divmod; + + CUTLASS_DEVICE + Mask(const int thread_idx, const int seqlen_q, const int seqlen_k, + const int window_size_left, const int window_size_right, const int sink_token_length, + cutlass::FastDivmod const &attention_chunk_divmod, + cutlass::FastDivmod const &qhead_per_khead_divmod) + : thread_idx(thread_idx) + , seqlen_q(seqlen_q) + , seqlen_k(seqlen_k) + , window_size_left(window_size_left) + , window_size_right(window_size_right) + , sink_token_length(sink_token_length) + , attention_chunk_divmod(attention_chunk_divmod) + , qhead_per_khead_divmod(qhead_per_khead_divmod) + { + }; + + template + CUTLASS_DEVICE + void apply(Tensor &tSrS, const int m_block, const int n_block) const { + static_assert(!(Causal_mask && Local_mask), "Cannot be both causal and local"); + static_assert(Layout::rank == 3, "Only support 3D Tensor"); + if (!Seqlenk_mask && !Causal_mask && !Local_mask) { return; } + + auto thread_mma = TiledMma{}.get_thread_slice(thread_idx); + auto thread0_mma = TiledMma{}.get_thread_slice(_0{}); + + static constexpr int Row = !SwapAB ? 0 : 1, Col = !SwapAB ? 1 : 0; + + Tensor cS = cute::make_identity_tensor(Shape, Int>{}); + Tensor tScS = thread_mma.partition_C(cS); + Tensor tSrS_rowcol = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout())); + Tensor tScS_rowcol = make_tensor(tScS.data(), flash::convert_layout_acc_rowcol(tScS.layout())); + Tensor t0ScS = thread0_mma.partition_C(cS); + Tensor t0ScS_rowcol = make_tensor(t0ScS.data(), flash::convert_layout_acc_rowcol(t0ScS.layout())); + // We want to use the col indices of thread0 to compare, since that is known at compile time. 
+ // So we subtract the limit by the first col index of this thread (get(tScS_rowcol(_0{}, _0{}))) + int const thread_col_offset = get(tScS_rowcol(_0{}, _0{})); + int const seqlenk_col_limit = seqlen_k - n_block * kBlockN - thread_col_offset; + if constexpr (!Causal_mask && !Local_mask) { + if constexpr (Seqlenk_mask) { // Just masking based on col + #pragma unroll + for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { + if (int(get(t0ScS_rowcol(_0{}, n))) >= seqlenk_col_limit) { + #pragma unroll + for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { tSrS_rowcol(m, n) = -INFINITY; } + } + } + } + } else { // mask based on both row and col + if constexpr (!SwapAB) { + // If PackGQA, we split the work of compute divmod among threads in the same row + static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{}); + static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0); + static_assert(!PackGQA || CUTE_STATIC_V(size<0>(tSrS_rowcol)) <= kMmaThreadsPerRow); + int mma_m_idx; + // Might get OOB but it's ok since we'll check it later + if constexpr (PackGQA) { + mma_m_idx = qhead_per_khead_divmod.divide(m_block * kBlockM + get(tScS_rowcol(thread_idx % kMmaThreadsPerRow, _0{}))); + } + int const causal_row_offset = 1 + seqlen_k - n_block * kBlockN - seqlen_q - thread_col_offset; + if constexpr (Causal_mask) { + #pragma unroll + for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { + int const row_idx = !PackGQA + ? get(tScS_rowcol(m, _0{})) + m_block * kBlockM + : __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow); + int const col_limit_right = !Seqlenk_mask + ? row_idx + causal_row_offset + : __viaddmin_s32(row_idx, causal_row_offset, seqlenk_col_limit); + #pragma unroll + for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { + if (int(get(t0ScS_rowcol(_0{}, n))) >= col_limit_right) { tSrS_rowcol(m, n) = -INFINITY; } + } + } + } else { + int const local_row_offset_right = causal_row_offset + window_size_right; + int const local_row_offset_left = causal_row_offset - 1 - window_size_left; + int const col_limit_sink = sink_token_length - n_block * kBlockN; // TODO: subtract thread_col_offset? + #pragma unroll + for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { + int const row_idx = !PackGQA + ? get(tScS_rowcol(m, _0{})) + m_block * kBlockM + : __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow); + int col_limit_right = !Seqlenk_mask + ? 
row_idx + local_row_offset_right + : __viaddmin_s32(row_idx, local_row_offset_right, seqlenk_col_limit); + int col_limit_left = row_idx + local_row_offset_left; + if (attention_chunk_divmod.divisor > 0) { + int col_limit_left_chunk = flash::round_down(attention_chunk_divmod, row_idx + seqlen_k - seqlen_q) - n_block * kBlockN - thread_col_offset; + col_limit_left = std::max(col_limit_left, col_limit_left_chunk); + col_limit_right = std::min(col_limit_right, col_limit_left_chunk + attention_chunk_divmod.divisor); + } + #pragma unroll + for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { + int const col_idx = int(get(t0ScS_rowcol(m, n))); + if (col_idx >= col_limit_right || (col_idx < col_limit_left && col_idx >= col_limit_sink)) { tSrS_rowcol(m, n) = -INFINITY; } + } + } + } + } else { + // TODO: backward does not support attention_chunk yet + int const thread_row_offset = get(tScS_rowcol(_0{}, _0{})); + int const causal_row_offset = seqlenk_col_limit - seqlen_q + m_block * kBlockM + thread_row_offset; + if constexpr (Causal_mask) { + #pragma unroll + for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { + int const col0 = int(get(t0ScS_rowcol(_0{}, n))); + // If col0 is beyond the column limit, we want to mask out the entire column, by setting + // row limit to be kBlockM. + int const row_limit_top = col0 >= seqlenk_col_limit ? kBlockM : col0 - causal_row_offset; + #pragma unroll + for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { + if (int(get(t0ScS_rowcol(m, _0{}))) < row_limit_top) { tSrS_rowcol(m, n) = -INFINITY; } + } + } + } else { + int const col_limit_sink = sink_token_length - n_block * kBlockN - thread_col_offset; + #pragma unroll + for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { + int const col0 = int(get(t0ScS_rowcol(_0{}, n))); + // If col0 is beyond the column limit, we want to mask out the entire column, by setting + // row limit to be kBlockM. + int const row_limit_top = col0 >= seqlenk_col_limit ? kBlockM : col0 - causal_row_offset - window_size_right; + int const row_limit_bot = col0 < col_limit_sink ? kBlockM : col0 - causal_row_offset + window_size_left; + #pragma unroll + for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { + int const row_idx = int(get(t0ScS_rowcol(m, _0{}))); + if (row_idx < row_limit_top || row_idx > row_limit_bot) { tSrS_rowcol(m, n) = -INFINITY; } + } + } + } + } + } + }; + +}; + +} // namespace flash diff --git a/flash-attn/named_barrier.hpp b/flash-attn/named_barrier.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a7dfb6439a234baffd8ffd154ab765c88021498c --- /dev/null +++ b/flash-attn/named_barrier.hpp @@ -0,0 +1,72 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cutlass/arch/barrier.h" + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// cutlass::arch::NamedBarrier::sync/arrive are only enabled Sm90 even though they work +// for Sm80 as well. We reimplement them here, enabled for both Sm90 and Sm80. 
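// Illustrative sketch, not part of the header below: how a user-facing barrier
// id maps to a hardware barrier id. Like cutlass::arch::NamedBarrier, the
// helpers below shift user ids past the barriers CUTLASS reserves for itself,
// and the result must stay below the 16 named barriers a CTA has. The reserved
// count used here is hypothetical; the real value comes from
// cutlass::arch::ReservedNamedBarriers::FirstUserBarrier.
constexpr unsigned kReservedBarriersHypothetical = 2;
constexpr unsigned hw_barrier_id(unsigned user_id) {
    return user_id + kReservedBarriersHypothetical;
}
static_assert(hw_barrier_id(/*FwdNamedBarriers::WarpSchedulerWG1*/ 1) == 3);
static_assert(hw_barrier_id(/*FwdNamedBarriers::AppendKV*/ 4) == 6);
static_assert(hw_barrier_id(4) < 16, "must fit in the 16 hardware named barriers");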
+ +CUTLASS_DEVICE +static void named_barrier_sync(uint32_t num_threads, uint32_t barrier_id_) { + static constexpr uint32_t ReservedNamedBarrierCount = static_cast(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier); + uint32_t barrier_id = barrier_id_ + ReservedNamedBarrierCount; + asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); + cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); +} + +CUTLASS_DEVICE +static void named_barrier_sync(uint32_t num_threads, cutlass::arch::ReservedNamedBarriers reserved_named_barriers) { + uint32_t barrier_id = static_cast(reserved_named_barriers); + asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); + cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); +} + +CUTLASS_DEVICE +static void named_barrier_arrive(uint32_t num_threads, uint32_t barrier_id_) { + static constexpr uint32_t ReservedNamedBarrierCount = static_cast(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier); + uint32_t barrier_id = barrier_id_ + ReservedNamedBarrierCount; + cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); + asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); +} + +CUTLASS_DEVICE +static void named_barrier_arrive(uint32_t num_threads, cutlass::arch::ReservedNamedBarriers reserved_named_barriers) { + uint32_t barrier_id = static_cast(reserved_named_barriers); + cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); + asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Enumerates the reserved named barriers to avoid potential conflicts + +enum class FwdNamedBarriers { + QueryEmpty = 0, + WarpSchedulerWG1 = 1, + WarpSchedulerWG2 = 2, + WarpSchedulerWG3 = 3, + AppendKV = 4, + QueryRotated = 5, + PFull = 6, + PEmpty = 7, +}; + +enum class BwdNamedBarriers { + KVEmpty = 0, + PdS = 1, + dQEmptyWG1 = 2, + dQEmptyWG2 = 3, + dQEmptyWG3 = 4, + dQFullWG1 = 5, + dQFullWG2 = 6, + dQFullWG3 = 7, +}; + +} // flash diff --git a/flash-attn/pack_gqa.h b/flash-attn/pack_gqa.h new file mode 100644 index 0000000000000000000000000000000000000000..160bf4306840bb55e72bfcc5ae6d10c99356e595 --- /dev/null +++ b/flash-attn/pack_gqa.h @@ -0,0 +1,255 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include "cutlass/fast_math.h" // For cutlass::FastDivmod + +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +struct PackGQAManager { + // We use CpAsync for Q, since TMA doesn't work there + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static constexpr int kGmemElemsPerStore = kGmemElemsPerLoad; + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each + // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. + // In the case of PackGQA, this reduces the number of times we need to call divmod. 
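// Illustrative sketch, not part of the struct below: the cache-line sizing the
// comment above describes, worked through for a hypothetical 2-byte element
// (fp16/bf16) with head dim 128. A 16-byte cp.async moves 8 such elements, a
// 128-byte cache line holds 64 of them, so 8 threads cooperate on each row and
// every thread issues 2 vectorized loads across the head dimension.
constexpr int kElemBytesEx         = 2;
constexpr int kHeadDimEx           = 128;
constexpr int kGmemElemsPerLoadEx  = 16 / kElemBytesEx;          // 8 values per 128-bit load
constexpr int kBytePerRowEx        = kHeadDimEx * kElemBytesEx;  // 256 bytes
constexpr int kBlockKGmemEx        = (kBytePerRowEx % 128 == 0 ? 128 : (kBytePerRowEx % 64 == 0 ? 64 : 32)) / kElemBytesEx;
constexpr int kGmemThreadsPerRowEx = kBlockKGmemEx / kGmemElemsPerLoadEx;
static_assert(kBlockKGmemEx == 64 && kGmemThreadsPerRowEx == 8);
static_assert(kHeadDimEx / kBlockKGmemEx == 2);  // 2 vectorized loads per row per thread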
+ static constexpr int kBytePerRow = kHeadDim * sizeof(Element); + static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; + static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow"); + // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where + // these threads share the same page table entry and share the work of computing pointers to paged K and paged V. + static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); + using GmemCopyAtomCpAsync = cute::Copy_Atom, Element>; + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + using GmemTiledCopyQCpAsync = decltype( + make_tiled_copy(GmemCopyAtomCpAsync{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per load + + // Was trying to have each WG loading Q to the rows in sQ that only that WG needs so that we only need + // to sync within each WG, but didn't seem to be any faster. + // using GmemLayoutAtomWG = Layout, Int, Int >, + // Stride, _128, _1>>; + // using GmemTiledCopyQCpAsyncWG = decltype( + // make_tiled_copy(GmemCopyAtomCpAsync{}, + // GmemLayoutAtomNew{}, + // Layout>>{})); // Val layout, 8 or 16 vals per load + + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per store + + template + CUTLASS_DEVICE + static auto + compute_ptr(Tensor &tensor, TensorC const &tRows, + cutlass::FastDivmod const &qhead_per_khead_divmod, int const thread_idx, int const m_block) { + // tensor of shape ((qhead_per_khead, seqlen_q)) + static constexpr int NumPtrPerThread = cute::ceil_div(CUTE_STATIC_V(cute::size(tRows)), NumThreadsPerRow); + using TensorType = typename Engine::value_type; + Tensor tPrPtr = make_tensor(Shape>{}); + #pragma unroll + for (int i = 0; i < NumPtrPerThread; ++i) { + int const row = i * NumThreads + get<0>(tRows(thread_idx % NumThreadsPerRow)); + int const idx = m_block * kBlockM + row; + int m_idx, h_idx; + m_idx = qhead_per_khead_divmod.divmod(h_idx, idx); + tPrPtr[i] = &tensor(make_coord(make_coord(h_idx, m_idx))); + } + return tPrPtr; + } + + + template + CUTLASS_DEVICE + static void + load_Q(TensormQ const &mQ, // ((qhead_per_khead, seqlen_q), headdim) + TensorsQ &sQ, // (kBlockM, kHeadDim) + cutlass::FastDivmod const &qhead_per_khead_divmod, + int const thread_idx, int const seqlen_q, int const m_block + ) + { + GmemTiledCopyQCpAsync gmem_tiled_copy_Q_cp_async; + // GmemTiledCopyQCpAsyncNew gmem_tiled_copy_Q_cp_async; + auto gmem_thr_copy_Q_cp_async = gmem_tiled_copy_Q_cp_async.get_thread_slice(thread_idx); + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tQcQ = gmem_thr_copy_Q_cp_async.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tQsQ = gmem_thr_copy_Q_cp_async.partition_D(sQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + // Tensor tQcQ_ = gmem_thr_copy_Q_cp_async.partition_S(cute::flat_divide(cQ, _64{})); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + // Tensor tQsQ_ = gmem_thr_copy_Q_cp_async.partition_D(cute::flat_divide(sQ, _64{})); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + // Tensor tQcQ = group_modes<1, rank(tQcQ_) - 1>(tQcQ_); + // Tensor tQsQ = group_modes<1, rank(tQsQ_) - 1>(tQsQ_); + Tensor tQpQ = 
make_tensor(make_shape(size<2>(tQsQ))); + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < size<1>(mQ); } + + // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for Q. + // We split the work among threads loading the same row of Q, then __shfl_sync the pointers. + Tensor mQ_0 = mQ(_, _0{}); + Tensor tQcQ_row = tQcQ(_0{}, _, _0{}); + Tensor tPrQPtr = compute_ptr(mQ_0, tQcQ_row, qhead_per_khead_divmod, thread_idx, m_block); + int const qhead_per_khead = qhead_per_khead_divmod.divisor; + #pragma unroll + for (int m = 0; m < size<1>(tQsQ); ++m) { + int idx = m_block * kBlockM + get<0>(tQcQ(_0{}, m, _0{})); + Element const* q_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrQPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); + if (idx < seqlen_q * qhead_per_khead) { + // if (thread_idx == 0) { printf("m: %d, m_idx: %d, h_idx: %d, q_ptr = %p, q_ptr_og = %p\n", m, m_idx, h_idx, q_ptr, &mQ_copy(0, make_coord(h_idx, m_idx), 0));} + Tensor mQ_cur = make_tensor(make_gmem_ptr(q_ptr), Shape>{}); + Tensor mQ_cur_copy = cute::tiled_divide(mQ_cur, Shape>{}); + #pragma unroll + for (int k = 0; k < size<2>(tQsQ); ++k) { + int ki = get<1>(tQcQ(_0{}, _0{}, k)) / kGmemElemsPerLoad; + // the "tiled_copy.with(tQpQ(k))"" will fill in zero for columns where tQpQ(k) is false + // TODO: check this + cute::copy(gmem_tiled_copy_Q_cp_async.with(tQpQ(k)), mQ_cur_copy(_, ki), tQsQ(_, m, k)); + } + } // Don't need to fill in 0s for sQ since we're not gonna write the output to gmem for those rows + } + }; + + template + CUTLASS_DEVICE + static void + store_LSE(TensormLSE &mLSE, // ((qhead_per_khead, seqlen_q)) + TensorsLSE const &tLSErLSE, // (kBlockM) split across threads according to tiled_mma + TiledMma tiled_mma, + cutlass::FastDivmod const &qhead_per_khead_divmod, + int const thread_idx, int const seqlen_o, int const m_block + ) + { + Tensor caccO = cute::make_identity_tensor(Shape, Int>{}); + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + Tensor taccOcO_row = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()))(_, _0{}); + CUTE_STATIC_ASSERT_V(size(tLSErLSE) == size(taccOcO_row)); // MMA_M + + // If PackGQA, we split the work of compute divmod among threads in the same row + static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{}); + static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0); + static_assert(CUTE_STATIC_V(size(tLSErLSE)) <= kMmaThreadsPerRow); + static_assert(CUTE_STATIC_V(size(taccOcO_row)) <= kMmaThreadsPerRow); + + Tensor tPrLSEPtr = compute_ptr(mLSE, taccOcO_row, qhead_per_khead_divmod, thread_idx, m_block); + static_assert(CUTE_STATIC_V(size(tPrLSEPtr)) == 1); + int const qhead_per_khead = qhead_per_khead_divmod.divisor; + #pragma unroll + for (int mi = 0; mi < size(tLSErLSE); ++mi) { + int const row = m_block * kBlockM + get<0>(taccOcO_row(mi)); + float* ptr_LSE_cur = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrLSEPtr[0]), mi % kMmaThreadsPerRow, kMmaThreadsPerRow)); + if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o * qhead_per_khead) { + *ptr_LSE_cur = tLSErLSE(mi); + } + } + }; + + template + CUTLASS_DEVICE + static void + store_O(TensormO &mO, // ((qhead_per_khead, seqlen_o), headdim) + TensorrO const &tOrO, // (kBlockM, kHeadDim) split across threads according to gmem_tiled_copy_O + 
cutlass::FastDivmod const &qhead_per_khead_divmod, + int const thread_idx, int const seqlen_o, int const m_block + ) + { + GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + Tensor cO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < size<1>(mO); } + + // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for O. + // We split the work among threads loading the same row of O, then __shfl_sync the pointers. + Tensor mO_0 = mO(_, _0{}); + Tensor tOcO_row = tOcO(_0{}, _, _0{}); + Tensor tPrOPtr = compute_ptr(mO_0, tOcO_row, qhead_per_khead_divmod, thread_idx, m_block); + int const qhead_per_khead = qhead_per_khead_divmod.divisor; + #pragma unroll + for (int m = 0; m < size<1>(tOrO); ++m) { + int idx = m_block * kBlockM + get<0>(tOcO(_0{}, m, _0{})); + Element* o_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrOPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); + if (idx < seqlen_o * qhead_per_khead) { + Tensor mO_cur = make_tensor(make_gmem_ptr(o_ptr), Shape>{}); + Tensor mO_cur_copy = cute::tiled_divide(mO_cur, Shape>{}); + #pragma unroll + for (int k = 0; k < size<2>(tOrO); ++k) { + int ki = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerStore; + if (tOpO(k)) { + cute::copy(gmem_tiled_copy_O, tOrO(_, m, k), mO_cur_copy(_, ki)); + } + } + } + } + }; + + template + CUTLASS_DEVICE + static void + store_O_direct(TensormO &mO, // ((qhead_per_khead, seqlen_o), headdim) + TensorrO const &tOrO, // (kBlockM, kHeadDim) split across threads according to tiled_mma + TiledMma tiled_mma, + cutlass::FastDivmod const &qhead_per_khead_divmod, + int const thread_idx, int const seqlen_o, int const m_block + ) + { + static constexpr int kGmemElemsPerStoreDirect = 2; + cute::Copy_Atom, Element> gmem_copy_direct; + // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout())); + Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int>{}); + + Tensor caccO = cute::make_identity_tensor(Shape, Int>{}); + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); + Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); + Tensor taccOcO_col = taccOcO_rowcol(_0{}, _); + + // If PackGQA, we split the work of compute divmod among threads in the same row + static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{}); + static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0); + static_assert(CUTE_STATIC_V(size(taccOcO_row)) <= kMmaThreadsPerRow); + + // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for O. + // We split the work among threads loading the same row of O, then __shfl_sync the pointers. 
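// Illustrative sketch (a .cu fragment, not part of this file): the pointer
// broadcast used by load_Q / store_O above. One lane per group of
// kGmemThreadsPerRow threads computes the expensive global-memory address, and
// the other lanes fetch it with __shfl_sync by round-tripping the pointer
// through a 64-bit integer. group_size must be a power of two <= 32 and all
// lanes of a group must sit in the same warp. The function name is hypothetical.
template <class T>
__device__ T* broadcast_ptr_in_group(T* ptr, int src_lane_in_group, int group_size) {
    unsigned long long const bits =
        __shfl_sync(0xffffffff, reinterpret_cast<unsigned long long>(ptr),
                    src_lane_in_group, group_size);
    return reinterpret_cast<T*>(bits);
}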
+ Tensor mO_0 = mO(_, _0{}); + Tensor tPrOPtr = compute_ptr(mO_0, taccOcO_row, qhead_per_khead_divmod, thread_idx, m_block); + static_assert(CUTE_STATIC_V(size(tPrOPtr)) == 1); + + int const qhead_per_khead = qhead_per_khead_divmod.divisor; + #pragma unroll + for (int m = 0; m < size<1>(tOrO_copy); ++m) { + int row = m_block * kBlockM + get<0>(taccOcO_row(m)); + Element* o_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrOPtr[0]), m % kMmaThreadsPerRow, kMmaThreadsPerRow)); + if (row < seqlen_o * qhead_per_khead) { + Tensor mO_cur = make_tensor(make_gmem_ptr(o_ptr), Shape>{}); + Tensor mO_cur_copy = cute::tiled_divide(mO_cur, Shape>{}); + #pragma unroll + for (int k = 0; k < size<2>(tOrO_copy); ++k) { + int col = get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)); + if (col < size<1>(mO)) { + cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), mO_cur_copy(_, col / kGmemElemsPerStoreDirect)); + } + } + } + } + }; + +}; + +} // namespace flash diff --git a/flash-attn/paged_kv.h b/flash-attn/paged_kv.h new file mode 100644 index 0000000000000000000000000000000000000000..9ea59bcc2a2378bf0de1d626410146a93cc74420 --- /dev/null +++ b/flash-attn/paged_kv.h @@ -0,0 +1,354 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include "cutlass/fast_math.h" // For cutlass::FastDivmod + +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +struct PagedKVManager { + // If KV_Same_Iter=false, then we do load_page_table(0), load_K(0), load_page_table(1), load_K(1), load_V(0), + // load_page_table(2), load_K(2), load_V(1), etc. + // So we need to compute the V pointers for the previous iteration. + + // LoadsPerRow_LB is the lower bound on number of loads per row in the K direction. This is useful for + // rotary where we want each thread to have at least 2 loads per row. + + static constexpr bool SameHeadDim = (kHeadDim == kHeadDimV); + static constexpr int kHeadDimGCD = cute::gcd(kHeadDim, kHeadDimV); + + // We use CpAsync for K and V if PagedKV, since TMA doesn't work there + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDimGCD % kGmemElemsPerLoad == 0, "Headdim and HeaddimV must be a multiple of kGmemElemsPerLoad"); + // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each + // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. + // In the case of PackGQA, this reduces the number of times we need to call divmod. + static_assert(kHeadDimGCD % LoadsPerRow_LB == 0, "Headdim and HeaddimV must be a multiple of LoadsPerRow_LB"); + static constexpr int kBytePerRow = kHeadDimGCD / LoadsPerRow_LB * sizeof(Element); + static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; + static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow"); + // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where + // these threads share the same page table entry and share the work of computing pointers to paged K and paged V. 
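// Illustrative sketch, not part of the struct below: the paged-KV addressing
// PagedKVManager implements. A (left-padded) token index is split into a
// logical page and an offset inside that page; cutlass::FastDivmod performs
// the same split without a hardware divide, and the page table then maps the
// logical page to a physical one (page = mPageTable[page_of(...)]).
// Numbers are hypothetical.
constexpr int page_of(int row_idx, int leftpad_k, int page_size) {
    return (row_idx + leftpad_k) / page_size;   // quotient from page_size_divmod.divmod
}
constexpr int offset_in_page(int row_idx, int leftpad_k, int page_size) {
    return (row_idx + leftpad_k) % page_size;   // remainder from page_size_divmod.divmod
}
static_assert(page_of(/*row_idx=*/500, /*leftpad_k=*/12, /*page_size=*/256) == 2);
static_assert(offset_in_page(500, 12, 256) == 0);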
+ static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); + using GmemCopyAtomCpAsync = cute::Copy_Atom, Element>; + using GmemLayoutAtomKVCpAsync = Layout, Int>, + Stride, _1>>; + using GmemTiledCopyKVCpAsync = decltype( + make_tiled_copy(GmemCopyAtomCpAsync{}, + GmemLayoutAtomKVCpAsync{}, + Layout>>{})); // Val layout, 8 or 16 vals per load + using GmemTiledCopyKVStore = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomKVCpAsync{}, + Layout>>{})); // Val layout, 8 or 16 vals per load + + using ShapeKV = cute::Shape; // (seqlen, d, head, batch) + using StrideKV = cute::Stride; + using ShapePageTable = cute::Shape; // (batch, max_num_pages_per_seq) + using StridePageTable = cute::Stride; + + using TensorPageTable = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapePageTable{}, StridePageTable{})(int(0), _)); + using TensorKV = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeKV{}, StrideKV{})(_, _, int(0), _)); + using GmemThrCopyKVCpAsync = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0))); + using TensortKcK = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); + using TensortKpK = decltype(make_tensor(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{})); + using TensortVcV = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); + using TensortVpV = decltype(make_tensor(make_shape(size<1>(TensortVcV{}), size<2>(TensortVcV{})), Stride<_0, _1>{})); + + // For PagedKV, it's expensive the calculate the pointers to K and V for each page table entry, + // since those require int64_t arithmetic. We optimize by having threads split this work. + // Typically there are 8 threads loading per row (e.g. hdim 64 and 128), and there are 11 rows + // that each thread needs to load for the case of hdim 128 and kBlockN = 176. + // So each of those 8 threads will calculate the K_ptr and V_ptr for 11 / 8 = 2 rows. + // We then use __shfl_sync to broadcast the pointers to the other threads in the warp. 
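// Illustrative sketch, not part of the struct: the arithmetic in the comment
// above, for the hypothetical NumThreads = 128, 8 threads per row,
// kBlockN = 176 configuration. 128 threads at 8 per row cover 16 rows per copy
// atom, so each thread touches 176 / 16 = 11 rows and precomputes
// ceil(11 / 8) = 2 page pointers, getting the rest via __shfl_sync.
constexpr int ceil_div_ex(int a, int b) { return (a + b - 1) / b; }
constexpr int threads_total   = 128;
constexpr int threads_per_row = 8;
constexpr int block_n         = 176;
constexpr int rows_per_thread = block_n / (threads_total / threads_per_row);    // 11
constexpr int ptrs_per_thread = ceil_div_ex(rows_per_thread, threads_per_row);  // 2
static_assert(rows_per_thread == 11 && ptrs_per_thread == 2);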
+ static_assert(CUTE_STATIC_V(size<1>(TensortKcK{})) == CUTE_STATIC_V(size<1>(TensortVcV{}))); + static constexpr int kPageEntryPerThread = cute::ceil_div(size<1>(TensortKcK{}), kGmemThreadsPerRow); + using TensorPageOffset = decltype(make_tensor>(Shape>{})); + using TensorKVPtr = decltype(make_tensor(Shape>{})); + + GmemTiledCopyKVCpAsync gmem_tiled_copy_kv; + cutlass::FastDivmod const &page_size_divmod; + cutlass::FastDivmod const &blockN_per_page_size_divmod; + int const thread_idx; + int const seqlen_k; + int const leftpad_k; + int const* const ptr_page_table; + GmemThrCopyKVCpAsync const gmem_thr_copy_kv; + TensorPageTable mPageTable; + TensorKV mK_paged, mV_paged; + TensortKpK tKpK; + TensortVpV tVpV; + TensorPageOffset tPrPageOffset; + TensorKVPtr tPrVPtr; + int bidb_kv_idx, bidb_kv_idx_prev, n_block_idx, n_block_idx_prev; // Only used for TMA + + CUTLASS_DEVICE + PagedKVManager(int const* const ptr_page_table_, + ShapePageTable const &shape_pagetable, StridePageTable const &stride_pagetable, + Element* const ptr_K, ShapeKV const &shape_K, StrideKV const &stride_K, + Element* const ptr_V, int const headdim_v, StrideKV const &stride_V, + cutlass::FastDivmod const &page_size_divmod, + cutlass::FastDivmod const &blockN_per_page_size_divmod, + int const bidb, int const bidh, int const thread_idx, int const seqlen_k, int const leftpad_k, + int bidb_kv_idx + ) + : page_size_divmod(page_size_divmod) + , blockN_per_page_size_divmod(blockN_per_page_size_divmod) + , thread_idx(thread_idx) + , seqlen_k(seqlen_k) + , leftpad_k(leftpad_k) + , ptr_page_table(ptr_page_table_) + , gmem_thr_copy_kv(gmem_tiled_copy_kv.get_thread_slice(thread_idx)) + , bidb_kv_idx(bidb_kv_idx) + , bidb_kv_idx_prev(bidb_kv_idx) + + { + mPageTable = make_tensor(make_gmem_ptr(ptr_page_table), shape_pagetable, stride_pagetable)(bidb, _); + mK_paged = make_tensor(make_gmem_ptr(ptr_K), shape_K, stride_K)(_, _, bidh, _); + auto shape_V = make_shape(get<0>(shape_K), headdim_v, get<2>(shape_K), get<3>(shape_K)); + mV_paged = make_tensor(make_gmem_ptr(ptr_V), shape_V, stride_V)(_, _, bidh, _); + tKpK = make_tensor(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{}); + Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); + #pragma unroll + for (int k = 0; k < size<1>(tKpK); ++k) { tKpK(_0{}, k) = get<1>(tKcK(_0{}, _0{}, k)) < get<1>(shape_K); } + Tensor tVpV_ = make_tensor(make_shape(size<1>(TensortVcV{}), size<2>(TensortVcV{})), Stride<_0, _1>{}); + Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); + #pragma unroll + for (int k = 0; k < size<1>(tVpV_); ++k) { tVpV_(_0{}, k) = get<1>(tVcV(_0{}, _0{}, k)) < get<1>(shape_V); } + tVpV = cute::conditional_return(tKpK, tVpV_); + }; + + template + CUTLASS_DEVICE + void load_page_table(const int n_block) { + // The uncoalesced gmem load is intentional. This is so that each thread only loads the page table entries + // it needs, and we don't need any sync between warps. + // Assuming 8 threads per row, and 176 rows, then the rows from 0 to 175 are loaded by + // threads 0, 8, 16, ..., 120, 1, 9, ..., 121, 2, 10, ..., 122, etc. 
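// Illustrative sketch, not part of the struct: the row assignment used by
// load_page_table below, for hypothetical NumThreads = 128 and 8 threads per
// row. Iteration i of thread t reads page-table entry
//   row = i * NumThreads + (t % 8) * (128 / 8) + t / 8,
// which is exactly the thread order in the comment above: row 0 goes to
// thread 0, row 1 to thread 8, ..., row 15 to thread 120, row 16 to thread 1.
constexpr int pt_row(int i, int t, int num_threads = 128, int threads_per_row_ = 8) {
    return i * num_threads
         + (t % threads_per_row_) * (num_threads / threads_per_row_)
         + t / threads_per_row_;
}
static_assert(pt_row(0, 0) == 0 && pt_row(1, 0) == 128);
static_assert(pt_row(0, 8) == 1 && pt_row(0, 1) == 16);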
+ #pragma unroll + for (int i = 0; i < kPageEntryPerThread; ++i) { + int const row = i * NumThreads + (thread_idx % kGmemThreadsPerRow) * (NumThreads / kGmemThreadsPerRow) + (thread_idx / kGmemThreadsPerRow); + int const row_idx = n_block * kBlockN + row; + int page_idx, page_offset; + page_idx = page_size_divmod.divmod(page_offset, row_idx + leftpad_k); + // Add the condition (i + 1) * NumThreads <= kBlockN since that is an upper bound of row + // and is known at compile time. It avoids branching when e.g., kBlockN = 176 and i = 0. + int const page = ((i + 1) * NumThreads <= kBlockN || row < kBlockN) && (!Seqlenk_mask || row_idx < seqlen_k) ? mPageTable[page_idx] : 0; + tPrPageOffset[i] = {page, page_offset}; + // if (cute::thread0()) { printf("row = %d, page_idx = %d, page_offset = %d, page = %d, leftpad_k = %d, seqlen_k = %d\n", row, page_idx, page_offset, page, leftpad_k, seqlen_k); } + } + if constexpr (First_iter && !KV_Same_Iter) { compute_V_ptr(); } + }; + + template + CUTLASS_DEVICE + void load_page_table_TMA(const int n_block) { + // We require that page size is a multiple of kBlockN, and there's no leftpad_k + if (ptr_page_table) { + bidb_kv_idx = mPageTable[blockN_per_page_size_divmod.divmod(n_block_idx, n_block)]; + } else { + n_block_idx = n_block; + } + if constexpr (First_iter && !KV_Same_Iter) { + bidb_kv_idx_prev = bidb_kv_idx; + n_block_idx_prev = n_block_idx; + } + }; + + CUTLASS_DEVICE + cute::tuple get_indices_for_K_TMA() { + return {n_block_idx, bidb_kv_idx}; + }; + + CUTLASS_DEVICE + cute::tuple get_indices_for_V_TMA() { + if constexpr (KV_Same_Iter) { + return {n_block_idx, bidb_kv_idx}; + } else { + cute::tuple const indices = {n_block_idx_prev, bidb_kv_idx_prev}; + bidb_kv_idx_prev = bidb_kv_idx; + n_block_idx_prev = n_block_idx; + return indices; + } + }; + + CUTLASS_DEVICE + TensorKVPtr compute_K_ptr() { + Tensor tPrKPtr = make_tensor(Shape>{}); + #pragma unroll + for (int i = 0; i < kPageEntryPerThread; ++i) { + auto [page, page_offset] = tPrPageOffset[i]; + tPrKPtr[i] = &mK_paged(page_offset, _0{}, page); + } + return tPrKPtr; + }; + + CUTLASS_DEVICE + void compute_V_ptr() { + #pragma unroll + for (int i = 0; i < kPageEntryPerThread; ++i) { + auto [page, page_offset] = tPrPageOffset[i]; + tPrVPtr[i] = &mV_paged(page_offset, _0{}, page); + } + }; + + template + CUTLASS_DEVICE + void load_K(const int n_block, TensorK &&sK) { + // Do we need bound check to make sure the row doesn't go above kBlockN + static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtomKVCpAsync{})) == 0; + + Tensor tPrKPtr = compute_K_ptr(); + + // Only for index calculation, since all the indices of thread 0 are known at compile time + auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); + Tensor tKsK = gmem_thr_copy_kv.partition_D(sK); + Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Repeat the partitioning with identity layouts + Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); + Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK); + + // We want to use the row indices of thread0 to compare, since that is known at compile time. + // So we subtract the limit by the first row index of this thread (get<0>(tKcK(_0{}, _0{}, _0{}))) + int const seqlenk_row_limit = -int(get<0>(tKcK(_0{}, _0{}, _0{}))) + (EvenN + ? seqlen_k - n_block * kBlockN + : (!Seqlenk_mask ? 
kBlockN : std::min(seqlen_k - n_block * kBlockN, kBlockN))); + #pragma unroll + for (int m = 0; m < size<1>(tKsK); ++m) { + bool const should_load = EvenN + ? (!Seqlenk_mask || get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit) + : get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit; + Element const* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); + Tensor mK_paged_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); + Tensor mK_paged_cur_copy = cute::tiled_divide(mK_paged_cur, Shape>{}); + if (should_load) { + #pragma unroll + for (int k = 0; k < size<2>(tKsK); ++k) { + int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; + cute::copy(gmem_tiled_copy_kv.with(tKpK(_0{}, k)), mK_paged_cur_copy(_, ki), tKsK(_, m, k)); + } + } // Don't need to clear out the rest of the smem since we'll mask out the scores anyway + } + }; + + template + CUTLASS_DEVICE + void load_V(const int n_block, TensorV &&sV) { + // Do we need bound check to make sure the row doesn't go above kBlockN + static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtomKVCpAsync{})) == 0; + + if constexpr (KV_Same_Iter) { compute_V_ptr(); } + // Only for index calculation, since all the indices of thread 0 are known at compile time + auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); + Tensor tVsV = gmem_thr_copy_kv.partition_D(sV); + Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Repeat the partitioning with identity layouts + Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); + Tensor t0VcV = gmem_thr0_copy_kv.partition_S(cV); + + int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tVcV(_0{}, _0{}, _0{})); + #pragma unroll + for (int m = 0; m < size<1>(tVsV); ++m) { + // Faster to rely on the cp.async to clear smem that are out of bound, + // rather than calling cute::clear directly. + // We have to be careful not to write to smem past `kBlockN` if !EvenN. 
+ // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to checked + if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tVcV(_0{}, m, _0{})) < kBlockN) { + bool const should_load = !Seqlenk_mask || get<0>(t0VcV(_0{}, m, _0{})) < seqlenk_row_limit; + Element const* v_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); + Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); + Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape>{}); + #pragma unroll + for (int k = 0; k < size<2>(tVsV); ++k) { + int const ki = get<1>(tVcV(_0{}, _0{}, k)) / kGmemElemsPerLoad; + cute::copy(gmem_tiled_copy_kv.with(tVpV(_0{}, k) && should_load), mV_paged_cur_copy(_, ki), tVsV(_, m, k)); + } + } + } + if constexpr (!KV_Same_Iter) { compute_V_ptr(); } + }; + + template + CUTLASS_DEVICE + void store_K(const int n_block, TensorK &&tKrK) { + Tensor tPrKPtr = compute_K_ptr(); + // We're using the same partitioning as GmemTiledCopyKVCpAsync (used for loading) + // Only for index calculation, since all the indices of thread 0 are known at compile time + auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); + Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Repeat the partitioning with identity layouts + Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); + Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK); + + GmemTiledCopyKVStore gmem_tiled_copy_kv_store; + // We want to use the row indices of thread0 to compare, since that is known at compile time. + // So we subtract the limit by the first row index of this thread (get<0>(tKcK(_0{}, _0{}, _0{}))) + // int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tKcK(_0{}, _0{}, _0{})); + int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tKcK(_0{}, _0{}, _0{})); + // if (threadIdx.x == 128) { printf("bidx = %d, bidy = %d, bidz = %d, seqlen_k = %d, seqlenk_row_limit = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, seqlen_k, seqlenk_row_limit); } + #pragma unroll + for (int m = 0; m < size<1>(tKrK); ++m) { + bool const should_load = get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit; + Element* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); + Tensor mK_paged_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); + Tensor mK_paged_cur_copy = cute::tiled_divide(mK_paged_cur, Shape>{}); + if (should_load) { + #pragma unroll + for (int k = 0; k < size<2>(tKrK); ++k) { + int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; + if (tKpK(_0{}, k)) { + cute::copy(gmem_tiled_copy_kv_store, tKrK(_, m, k), mK_paged_cur_copy(_, ki)); + } + } + } + } + }; + + template + CUTLASS_DEVICE + void store_V(const int n_block, TensorV &&tVrV) { + if constexpr (KV_Same_Iter) { compute_V_ptr(); } + // Only for index calculation, since all the indices of thread 0 are known at compile time + auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); + Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Repeat the partitioning with identity layouts + Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); + Tensor t0VcV = gmem_thr0_copy_kv.partition_S(cV); + + GmemTiledCopyKVStore gmem_tiled_copy_kv_store; + int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tVcV(_0{}, _0{}, _0{})); + #pragma 
unroll + for (int m = 0; m < size<1>(tVrV); ++m) { + bool const should_load = get<0>(t0VcV(_0{}, m, _0{})) < seqlenk_row_limit; + Element* v_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); + Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); + Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape>{}); + if (should_load) { + #pragma unroll + for (int k = 0; k < size<2>(tVrV); ++k) { + int const ki = get<1>(tVcV(_0{}, _0{}, k)) / kGmemElemsPerLoad; + if (tVpV(_0{}, k)) { + cute::copy(gmem_tiled_copy_kv_store, tVrV(_, m, k), mV_paged_cur_copy(_, ki)); + } + } + } + } + if constexpr (!KV_Same_Iter) { compute_V_ptr(); } + }; + + +}; + +} // namespace flash diff --git a/flash-attn/rotary.h b/flash-attn/rotary.h new file mode 100644 index 0000000000000000000000000000000000000000..aa3602cc79557d720d3bdd3c500be0e895a88ba5 --- /dev/null +++ b/flash-attn/rotary.h @@ -0,0 +1,489 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void +apply_rotary_interleaved(Tensor &rK, + Tensor const &rCos, + Tensor const &rSin) { + CUTE_STATIC_ASSERT_V(rank(rK) == _1{}); + CUTE_STATIC_ASSERT_V(rank(rCos) == _1{}); + CUTE_STATIC_ASSERT_V(rank(rSin) == _1{}); + CUTE_STATIC_ASSERT_V(size<0>(rCos) == size<0>(rSin)); + static_assert(decltype(size<0>(rK))::value == decltype(size<0>(rCos))::value * 2); + static_assert(decltype(size<0>(rCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor K_fp32 = make_tensor_like(rK); + convert_type_out(rK, K_fp32); + Tensor cos_fp32 = make_tensor_like(rCos); + convert_type_out(rCos, cos_fp32); + Tensor sin_fp32 = make_tensor_like(rSin); + convert_type_out(rSin, sin_fp32); + #pragma unroll + for (int i = 0; i < size<0>(K_fp32) / 2; ++i) { + float real = K_fp32[2 * i] * cos_fp32[i] - K_fp32[2 * i + 1] * sin_fp32[i]; + float imag = K_fp32[2 * i] * sin_fp32[i] + K_fp32[2 * i + 1] * cos_fp32[i]; + K_fp32[2 * i] = real; + K_fp32[2 * i + 1] = imag; + } + convert_type_out(K_fp32, rK); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void +apply_rotary_contiguous(Tensor &rK_left, + Tensor &rK_right, + Tensor const &rCos, + Tensor const &rSin) { + CUTE_STATIC_ASSERT_V(rank(rK_left) == _1{}); + CUTE_STATIC_ASSERT_V(rank(rK_right) == _1{}); + CUTE_STATIC_ASSERT_V(rank(rCos) == _1{}); + CUTE_STATIC_ASSERT_V(rank(rSin) == _1{}); + CUTE_STATIC_ASSERT_V(size<0>(rK_left) == size<0>(rK_right)); + CUTE_STATIC_ASSERT_V(size<0>(rK_left) == size<0>(rCos)); + CUTE_STATIC_ASSERT_V(size<0>(rCos) == size<0>(rSin)); + static_assert(decltype(size<0>(rCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor K_left_fp32 = make_tensor_like(rK_left); + convert_type_out(rK_left, K_left_fp32); + Tensor K_right_fp32 = make_tensor_like(rK_right); + convert_type_out(rK_right, K_right_fp32); + Tensor cos_fp32 = make_tensor_like(rCos); + convert_type_out(rCos, cos_fp32); + Tensor sin_fp32 = make_tensor_like(rSin); + 
convert_type_out(rSin, sin_fp32); + #pragma unroll + for (int i = 0; i < size<0>(K_left_fp32); ++i) { + float real = K_left_fp32[i] * cos_fp32[i] - K_right_fp32[i] * sin_fp32[i]; + float imag = K_left_fp32[i] * sin_fp32[i] + K_right_fp32[i] * cos_fp32[i]; + K_left_fp32[i] = real; + K_right_fp32[i] = imag; + } + convert_type_out(K_left_fp32, rK_left); + convert_type_out(K_right_fp32, rK_right); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rotary { + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each + // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. + // We want each thread to have at least 2 loads in the K direction since in the case of non-interleaved + // rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), each thread will + // load twice from the same row. + static constexpr int kBytePerHalfRow = kHeadDim / 2 * sizeof(Element); + static constexpr int kBlockKGmem = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element); + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; + static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow"); + // We assume threads loading the same row are in the same warp. + static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); + + using LayoutAtom = Layout, Int>, + Stride, _1>>; + using TiledCopyQK = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + LayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per store + using GmemTiledCopyRotary = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + LayoutAtom{}, + Layout>>{})); // Val layout, 4 or 8 vals per store + using GmemTiledCopyRotaryCont = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + LayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per store + + using ShapeRotary = cute::Shape; // (seqlen_ro, rotary_dim // 2) + using StrideRotary = cute::Stride; + + using GmemThrCopyRotary = decltype(GmemTiledCopyRotary{}.get_thread_slice(int(0))); + using GmemThrCopyRotaryCont = decltype(GmemTiledCopyRotaryCont{}.get_thread_slice(int(0))); + using TensortRcR = decltype(GmemTiledCopyRotary{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); + using TensortRpR = decltype(make_tensor(make_shape(size<2>(TensortRcR{})))); + using TensortRcRCont = decltype(GmemTiledCopyRotaryCont{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); + using TensortRpRCont = decltype(make_tensor(make_shape(size<2>(TensortRcRCont{})))); + using TensormR = decltype(make_tensor( + make_gmem_ptr((Element const*)nullptr), + ShapeRotary{}, + make_stride(cute::conditional_return(_0{}, int64_t(0)), _1{}))); + using TensortRgR = decltype( + GmemTiledCopyRotary{}.get_thread_slice(int(0)).partition_S(make_tensor( + make_gmem_ptr((Element const*)nullptr), + make_shape(Int{}, Int{}, int(0)), + make_stride(cute::conditional_return(_0{}, int64_t(0)), _1{}, cute::conditional_return(_0{}, int64_t(0)))))); + using TensortRgRCont = decltype( + GmemTiledCopyRotaryCont{}.get_thread_slice(int(0)).partition_S(make_tensor( 
+ make_gmem_ptr((Element const*)nullptr), + make_shape(Int{}, Int{}, int(0)), + make_stride(cute::conditional_return(_0{}, int64_t(0)), _1{}, cute::conditional_return(_0{}, int64_t(0)))))); + + GmemTiledCopyRotary gmem_tiled_copy_rotary; + GmemTiledCopyRotaryCont gmem_tiled_copy_rotary_cont; + bool const is_rotary_interleaved; + int const rotary_dim; + int const thread_idx; + int const max_seqlen; + GmemThrCopyRotary const gmem_thr_copy_rotary; + GmemThrCopyRotaryCont const gmem_thr_copy_rotary_cont; + TensortRpR tRpR; + TensortRpRCont tRpRCont; + TensormR mCos, mSin; + TensortRgR tRgCos, tRgSin; + TensortRgRCont tRgCosCont, tRgSinCont; + + CUTLASS_DEVICE + Rotary(Element const* const ptr_rotary_cos, ShapeRotary const &shape_rotary, StrideRotary const &stride_rotary_cos_, + Element const* const ptr_rotary_sin, StrideRotary const &stride_rotary_sin_, + bool const is_rotary_interleaved, int const thread_idx, int const max_seqlen, int const start_idx) + : is_rotary_interleaved(is_rotary_interleaved) + , rotary_dim(get<1>(shape_rotary) * 2) + , thread_idx(thread_idx) + , max_seqlen(max_seqlen) + , gmem_thr_copy_rotary(gmem_tiled_copy_rotary.get_thread_slice(thread_idx)) + , gmem_thr_copy_rotary_cont(gmem_tiled_copy_rotary_cont.get_thread_slice(thread_idx)) + + { + auto stride_rotary_cos = make_stride(cute::conditional_return(get<0>(stride_rotary_cos_), _0{}), get<1>(stride_rotary_cos_)); + auto stride_rotary_sin = make_stride(cute::conditional_return(get<0>(stride_rotary_sin_), _0{}), get<1>(stride_rotary_sin_)); + mCos = make_tensor(make_gmem_ptr(ptr_rotary_cos + start_idx * get<0>(stride_rotary_cos_)), shape_rotary, stride_rotary_cos); + mSin = make_tensor(make_gmem_ptr(ptr_rotary_sin + start_idx * get<0>(stride_rotary_sin_)), shape_rotary, stride_rotary_sin); + Tensor gCos = local_tile(mCos, Shape, Int>{}, make_coord(_, _0{})); // (MN, K / 2, _) + Tensor gSin = local_tile(mSin, Shape, Int>{}, make_coord(_, _0{})); // (MN, K / 2, _) + tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCos); + tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSin); + Tensor cR = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K / 2) + Tensor tRcR = gmem_thr_copy_rotary.partition_D(cR); + tRpR = make_tensor(make_shape(size<2>(tRcR))); + #pragma unroll + for (int k = 0; k < size(tRpR); ++k) { tRpR(k) = get<1>(tRcR(_0{}, _0{}, k)) < get<1>(shape_rotary); } + Tensor tRcRCont = gmem_thr_copy_rotary_cont.partition_D(cR); + tRpRCont = make_tensor(make_shape(size<2>(tRcRCont))); + #pragma unroll + for (int k = 0; k < size(tRpRCont); ++k) { tRpRCont(k) = get<1>(tRcRCont(_0{}, _0{}, k)) < get<1>(shape_rotary); } + }; + + template + CUTLASS_DEVICE + auto load_cos_sin(int const block) { + using GmemTiledCopyRo = std::conditional_t; + auto gmem_thr_copy_ro = cute::conditional_return(gmem_thr_copy_rotary, gmem_thr_copy_rotary_cont); + Tensor tRpRCur = cute::conditional_return(tRpR, tRpRCont); + Tensor tRgCosCur = cute::conditional_return(tRgCos, tRgCosCont)(_, _, _, block); + Tensor tRgSinCur = cute::conditional_return(tRgSin, tRgSinCont)(_, _, _, block); + // make_tensor_like, not make_fragment_like. 
If the row_stride is _0{} we want to keep it that way + Tensor tRrCos = make_tensor_like(tRgCosCur); + Tensor tRrSin = make_tensor_like(tRgSinCur); + Tensor cR = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K / 2) + Tensor tRcR = gmem_thr_copy_ro.partition_D(cR); + // If FixedPosition, only copy the first row as we only need the cos/sin for position cache_seqlens + #pragma unroll + for (int m = 0; m < (!FixedPosition ? size<1>(tRrCos) : 1); ++m) { + if (get<0>(tRcR(_0{}, m, _0{})) < std::min(max_seqlen - block * kBlockMN, kBlockMN)) { + #pragma unroll + for (int k = 0; k < size<2>(tRrCos); ++k) { + if (tRpRCur(k)) { + cute::copy(GmemTiledCopyRo{}, tRgCosCur(_, m, k), tRrCos(_, m, k)); + cute::copy(GmemTiledCopyRo{}, tRgSinCur(_, m, k), tRrSin(_, m, k)); + } + } + } + } + return cute::make_tuple(tRrCos, tRrSin);; + } + + template + CUTLASS_DEVICE + auto load_cos_sin_packgqa(int const block, cutlass::FastDivmod const &qhead_per_khead_divmod) { + static constexpr int kGmemElemsPerLoadCur = kInterleaved ? kGmemElemsPerLoad / 2 : kGmemElemsPerLoad; + using GmemTiledCopyRo = std::conditional_t; + auto gmem_thr_copy_ro = cute::conditional_return(gmem_thr_copy_rotary, gmem_thr_copy_rotary_cont); + Tensor tRpRCur = cute::conditional_return(tRpR, tRpRCont); + // make_tensor_like, not make_fragment_like. If the row_stride is _0{} we want to keep it that way + Tensor tRrCos = make_tensor_like(cute::conditional_return(tRgCos, tRgCosCont)(_, _, _, _0{})); + Tensor tRrSin = make_tensor_like(cute::conditional_return(tRgSin, tRgSinCont)(_, _, _, _0{})); + int const qhead_per_khead = qhead_per_khead_divmod.divisor; + Tensor cR = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K / 2) + Tensor tRcR = gmem_thr_copy_ro.partition_D(cR); + + // The main bottleneck here is actually instruction cache misses. + + // Similar to PagedKVNonTMA, it's expensive to compute the pointers. + // We split the work among threads loading the same row, then __shfl_sync the pointers. + static constexpr int NumPtrPerThread = cute::ceil_div(CUTE_STATIC_V(cute::size<1>(tRrCos)), kGmemThreadsPerRow); + Tensor tPrCosPtr = make_tensor(Shape>{}); + Tensor tPrSinPtr = make_tensor(Shape>{}); + #pragma unroll + for (int i = 0; i < NumPtrPerThread; ++i) { + int const row = i * NumThreads + get<0>(tRcR(_0{}, thread_idx % kGmemThreadsPerRow, _0{})); + int const idx = block * kBlockMN + row; + int row_actual = qhead_per_khead_divmod.divide(idx); + tPrCosPtr[i] = &mCos(row_actual, _0{}); + tPrSinPtr[i] = &mSin(row_actual, _0{}); + } + + #pragma unroll + for (int m = 0; m < (!FixedPosition ? 
size<1>(tRgCos) : 1); ++m) { + int const idx = block * kBlockMN + get<0>(tRcR(_0{}, m, _0{})); + Element const* cos_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrCosPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); + Element const* sin_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrSinPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); + if (idx < max_seqlen * qhead_per_khead) { + Tensor mCos_copy = cute::tiled_divide(make_tensor(make_gmem_ptr(cos_ptr), Shape>{}), + Shape>{}); + Tensor mSin_copy = cute::tiled_divide(make_tensor(make_gmem_ptr(sin_ptr), Shape>{}), + Shape>{}); + #pragma unroll + for (int k = 0; k < size<2>(tRgCos); ++k) { + int const ki = get<1>(tRcR(_0{}, _0{}, k)) / (kGmemElemsPerLoadCur); + if (tRpRCur(k)) { + cute::copy(GmemTiledCopyRo{}, mCos_copy(_, ki), tRrCos(_, m, k)); + cute::copy(GmemTiledCopyRo{}, mSin_copy(_, ki), tRrSin(_, m, k)); + } + } + } + } + return cute::make_tuple(tRrCos, tRrSin); + } + + template + CUTLASS_DEVICE + void + apply_Q_interleaved(TensorsQ &sQ, // (kBlockM, kHeadDim) + TensortRrR const &tRrCos, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotary + TensortRrR const &tRrSin, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotary + int const m_block, int const qhead_per_khead=1) + { + TiledCopyQK tiled_copy_q; + auto gmem_thr_copy_q = tiled_copy_q.get_thread_slice(thread_idx); + Tensor tQsQ = gmem_thr_copy_q.partition_S(sQ); + Tensor tQcQ = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape, Int>{})); + + CUTE_STATIC_ASSERT_V(rank(tQsQ) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrCos) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrSin) == _3{}); + CUTE_STATIC_ASSERT_V(size<1>(tQsQ) == size<1>(tRrCos)); + CUTE_STATIC_ASSERT_V(size<2>(tQsQ) == size<2>(tRrCos)); + CUTE_STATIC_ASSERT_V(size<1>(tQsQ) == size<1>(tRrSin)); + CUTE_STATIC_ASSERT_V(size<2>(tQsQ) == size<2>(tRrSin)); + CUTE_STATIC_ASSERT_V(size<0>(tRrCos) == size<0>(tRrSin)); + static_assert(decltype(size<0>(tQsQ))::value == decltype(size<0>(tRrCos))::value * 2); + static_assert(decltype(size<0>(tRrCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + + #pragma unroll + for (int m = 0; m < size<1>(tQsQ); ++m) { + if (get<0>(tQcQ(_0{}, m, _0{})) < std::min(max_seqlen * qhead_per_khead - m_block * kBlockMN, kBlockMN)) { + #pragma unroll + for (int k = 0; k < size<2>(tQsQ); ++k) { + if (tRpR(k)) { + Tensor rQ = make_fragment_like(tQsQ(_, m, k)); + cute::copy(tiled_copy_q, tQsQ(_, m, k), rQ); + apply_rotary_interleaved(rQ, tRrCos(_, m, k), tRrSin(_, m, k)); + cute::copy(tiled_copy_q, rQ, tQsQ(_, m, k)); + } + } + } + } + }; + + template + CUTLASS_DEVICE + void + apply_Q_contiguous(TensorsQ &sQ, // (kBlockM, kHeadDim) + TensortRrR const &tRrCosCont, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotaryCont + TensortRrR const &tRrSinCont, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotaryCont + int const m_block, int const qhead_per_khead=1) + { + TiledCopyQK tiled_copy_q; + auto gmem_thr_copy_q = tiled_copy_q.get_thread_slice(thread_idx); + Tensor sQ_copy = cute::tiled_divide(sQ, Shape<_1, Int>{}); + Tensor tQcQ = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape, Int>{})); + + CUTE_STATIC_ASSERT_V(rank(tQcQ) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrCosCont) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrSinCont) == _3{}); + CUTE_STATIC_ASSERT_V(size<1>(tQcQ) == size<1>(tRrCosCont)); + CUTE_STATIC_ASSERT_V(size<2>(tQcQ) == 
size<2>(tRrCosCont)); + CUTE_STATIC_ASSERT_V(size<1>(tQcQ) == size<1>(tRrSinCont)); + CUTE_STATIC_ASSERT_V(size<2>(tQcQ) == size<2>(tRrSinCont)); + CUTE_STATIC_ASSERT_V(size<0>(tRrCosCont) == size<0>(tRrSinCont)); + CUTE_STATIC_ASSERT_V(size<0>(tQcQ) == size<0>(tRrCosCont)); + static_assert(decltype(size<0>(tRrCosCont))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + + #pragma unroll + for (int m = 0; m < size<1>(tQcQ); ++m) { + int const row = get<0>(tQcQ(_0{}, m, _0{})); + if (row < std::min(max_seqlen * qhead_per_khead - m_block * kBlockMN, kBlockMN)) { + #pragma unroll + for (int k = 0; k < size<2>(tQcQ); ++k) { + int const col = get<1>(tQcQ(_0{}, _0{}, k)); + if (col < rotary_dim / 2) { + int const col_idx_left = col / kGmemElemsPerLoad; + int const col_idx_right = col / kGmemElemsPerLoad + rotary_dim / (2 * kGmemElemsPerLoad); + Tensor rQ_left = make_fragment_like(sQ_copy(_, row, col_idx_left)); + cute::copy(tiled_copy_q, sQ_copy(_, row, col_idx_left), rQ_left); + Tensor rQ_right = make_fragment_like(rQ_left); + cute::copy(tiled_copy_q, sQ_copy(_, row, col_idx_right), rQ_right); + apply_rotary_contiguous(rQ_left, rQ_right, tRrCosCont(_, m, k), tRrSinCont(_, m, k)); + cute::copy(tiled_copy_q, rQ_left, sQ_copy(_, row, col_idx_left)); + cute::copy(tiled_copy_q, rQ_right, sQ_copy(_, row, col_idx_right)); + } + } + } + } + }; + + template + CUTLASS_DEVICE + void + apply_K_interleaved(TensorsK const &sK, // (kBlockN, kHeadDim) + TensorgK &gK, // (kBlockN, kHeadDim) + TensorpK const &tKpK, // (kBlockN, kHeadDim) split according to ThrCopyKV + TensortRrR const &tRrCos, // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotary + TensortRrR const &tRrSin, // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotary + TensorKPtr const &tPrKPtr, + int const n_block) + { + TiledCopyQK tiled_copy_k; + auto gmem_thr_copy_q = tiled_copy_k.get_thread_slice(thread_idx); + Tensor tKsK = gmem_thr_copy_q.partition_S(sK); + Tensor tKgK = gmem_thr_copy_q.partition_S(gK); + Tensor tKcK = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape, Int>{})); + + CUTE_STATIC_ASSERT_V(rank(tKsK) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrCos) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrSin) == _3{}); + CUTE_STATIC_ASSERT_V(size<1>(tKsK) == size<1>(tRrCos)); + CUTE_STATIC_ASSERT_V(size<2>(tKsK) == size<2>(tRrCos)); + CUTE_STATIC_ASSERT_V(size<1>(tKsK) == size<1>(tRrSin)); + CUTE_STATIC_ASSERT_V(size<2>(tKsK) == size<2>(tRrSin)); + CUTE_STATIC_ASSERT_V(size<0>(tRrCos) == size<0>(tRrSin)); + static_assert(decltype(size<0>(tKsK))::value == decltype(size<0>(tRrCos))::value * 2); + static_assert(decltype(size<0>(tRrCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + if constexpr (PagedKVNonTMA) { + static_assert(decltype(size(tPrKPtr))::value == cute::ceil_div(size<1>(tKcK), kGmemThreadsPerRow)); + } + + #pragma unroll + for (int m = 0; m < size<1>(tKsK); ++m) { + int const row = get<0>(tKcK(_0{}, m, _0{})); + auto mK_cur_copy = [&] { + if constexpr (PagedKVNonTMA) { + Element* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); + Tensor mK_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); + return cute::tiled_divide(mK_cur, Shape>{}); + } else { + return nullptr; + } + }(); + if (row < std::min(max_seqlen - n_block * kBlockMN, kBlockMN)) { + #pragma unroll + for (int k = 0; k < size<2>(tKsK); ++k) { + if (tKpK(k)) { + Tensor rK = make_fragment_like(tKsK(_, m, k)); + 
cute::copy(tiled_copy_k, tKsK(_, m, k), rK); + if (tRpR(k)) { apply_rotary_interleaved(rK, tRrCos(_, m, k), tRrSin(_, m, k)); } + if constexpr (!PagedKVNonTMA) { + cute::copy(tiled_copy_k, rK, tKgK(_, m, k)); + } else { + int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; + cute::copy(tiled_copy_k, rK, mK_cur_copy(_, ki)); + } + } + } + } + } + }; + + template + CUTLASS_DEVICE + void + apply_K_contiguous(TensorsK const &sK, // (kBlockN, kHeadDim) + TensorgK &gK, // (kBlockN, kHeadDim) + TensorpK const &tKpK, // (kBlockN, kHeadDim) split according to ThrCopyKV + TensortRrR const &tRrCosCont, // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotaryCont + TensortRrR const &tRrSinCont, // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotaryCont + TensorKPtr const &tPrKPtr, + int const n_block, int const max_k) + { + TiledCopyQK tiled_copy_k; + auto gmem_thr_copy_q = tiled_copy_k.get_thread_slice(thread_idx); + Tensor sK_copy = cute::tiled_divide(sK, Shape<_1, Int>{}); + Tensor gK_copy = cute::tiled_divide(gK, Shape<_1, Int>{}); + Tensor tKcK = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape, Int>{})); + + CUTE_STATIC_ASSERT_V(rank(tKcK) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrCosCont) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrSinCont) == _3{}); + CUTE_STATIC_ASSERT_V(size<1>(tKcK) == size<1>(tRrCosCont)); + CUTE_STATIC_ASSERT_V(size<2>(tKcK) == size<2>(tRrCosCont)); + CUTE_STATIC_ASSERT_V(size<1>(tKcK) == size<1>(tRrSinCont)); + CUTE_STATIC_ASSERT_V(size<2>(tKcK) == size<2>(tRrSinCont)); + CUTE_STATIC_ASSERT_V(size<0>(tRrCosCont) == size<0>(tRrSinCont)); + CUTE_STATIC_ASSERT_V(size<0>(tKcK) == size<0>(tRrCosCont)); + static_assert(decltype(size<0>(tRrCosCont))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + if constexpr (PagedKVNonTMA) { + static_assert(decltype(size(tPrKPtr))::value == cute::ceil_div(size<1>(tKcK), kGmemThreadsPerRow)); + } + + const int ro_dim_vec = rotary_dim / kGmemElemsPerLoad; + const int non_ro_dim_vec = (max_k - rotary_dim) / kGmemElemsPerLoad; + #pragma unroll + for (int m = 0; m < size<1>(tKcK); ++m) { + int const row = get<0>(tKcK(_0{}, m, _0{})); + Tensor gK_cur_copy = [&] { + if constexpr (PagedKVNonTMA) { + Element* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); + Tensor mK_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); + return cute::tiled_divide(mK_cur, Shape>{}); + } else { + return gK_copy(_, row, _); + } + }(); + if (row < std::min(max_seqlen - n_block * kBlockMN, kBlockMN)) { + #pragma unroll + for (int k = 0; k < size<2>(tKcK); ++k) { + if (tKpK(k)) { + int const col = get<1>(tKcK(_0{}, _0{}, k)); + bool rotate = col < rotary_dim / 2; + int const col_idx_left = rotate ? col / kGmemElemsPerLoad : (col + rotary_dim / 2) / kGmemElemsPerLoad; + int const col_idx_right = col_idx_left + (rotate ? 
ro_dim_vec / 2 : non_ro_dim_vec / 2); + Tensor rK_left = make_fragment_like(sK_copy(_, row, col_idx_left)); + cute::copy(tiled_copy_k, sK_copy(_, row, col_idx_left), rK_left); + Tensor rK_right = make_fragment_like(rK_left); + cute::copy(tiled_copy_k, sK_copy(_, row, col_idx_right), rK_right); + if (rotate) { + apply_rotary_contiguous(rK_left, rK_right, tRrCosCont(_, m, k), tRrSinCont(_, m, k)); + } + cute::copy(tiled_copy_k, rK_left, gK_cur_copy(_, col_idx_left)); + if (col_idx_right * kGmemElemsPerLoad < max_k) { + cute::copy(tiled_copy_k, rK_right, gK_cur_copy(_, col_idx_right)); + } + } + } + } + } + }; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/flash-attn/seqlen.h b/flash-attn/seqlen.h new file mode 100644 index 0000000000000000000000000000000000000000..5547238b34892248ac2e2b2c464556136e304468 --- /dev/null +++ b/flash-attn/seqlen.h @@ -0,0 +1,95 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +namespace flash { + +// We consolidate all the info related to sequence length here. This is so that we can do all +// the gmem reads once at the beginning of each tile, rather than having to repeat these reads +// to compute various things like n_block_min, n_block_max, etc. + +template +struct SeqlenInfo { + + int const offset, offset_padded; + int const seqlen; + + CUTLASS_DEVICE + SeqlenInfo(int const bidb, int const seqlen_static, int const* const cu_seqlens, int const* const seqused) + : offset(!Varlen || cu_seqlens == nullptr ? 0 : cu_seqlens[bidb]) + , offset_padded(!Varlen || cu_seqlens == nullptr ? 0 : (cu_seqlens[bidb] + bidb * kBlock) / kBlock * kBlock) + , seqlen(!Varlen + ? seqlen_static + : (seqused ? seqused[bidb] : (cu_seqlens ? cu_seqlens[bidb + 1] - cu_seqlens[bidb] : seqlen_static))) + { + } + +}; + +template +struct SeqlenInfoQK { + + int const offset_q, offset_k, offset_q_padded; + int const seqlen_q, seqlen_k; + + CUTLASS_DEVICE + SeqlenInfoQK(int const bidb, int const seqlen_q_static, int const seqlen_k_static, + int const* const cu_seqlens_q, int const* const cu_seqlens_k, + int const* const seqused_q, int const* const seqused_k + ) + : offset_q(!Varlen || cu_seqlens_q == nullptr ? 0 : cu_seqlens_q[bidb]) + , offset_k(!Varlen || cu_seqlens_k == nullptr ? 0 : cu_seqlens_k[bidb]) + // If varlen, the layout for dPSum, LSE_log2, and dQaccum is that we pad each sequence in the batch + // by an extra kBlockM, so that the write for each sequence doesn't touch the next sequence. + // Sequence i starts at cu_seqlens[i] + i * kBlockM and ends at cu_seqlens[i + 1] + i * kBlockM + // However, the start must align to multiples of kBlockM. + , offset_q_padded(!Varlen || cu_seqlens_q == nullptr ? 0 : (cu_seqlens_q[bidb] + bidb * kBlockM) / kBlockM * kBlockM) + , seqlen_q(!Varlen + ? seqlen_q_static + : (seqused_q ? seqused_q[bidb] : (cu_seqlens_q ? cu_seqlens_q[bidb + 1] - cu_seqlens_q[bidb] : seqlen_q_static))) + , seqlen_k(!Varlen + ? seqlen_k_static + : (seqused_k ? seqused_k[bidb] : (cu_seqlens_k ? 
cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb] : seqlen_k_static))) + { + } + +}; + +template +struct SeqlenInfoQKNewK { + + static_assert(!(AppendKV && !Varlen), "AppendKV is only supported with Varlen"); + + int const leftpad_k; + int const offset_q, offset_k, offset_k_new; + int const seqlen_q, seqlen_k_og, seqlen_k_new, seqlen_k, seqlen_rotary; + + CUTLASS_DEVICE + SeqlenInfoQKNewK(int const bidb, int const seqlen_q_static, int const seqlen_k_static, int const shape_K_new_0, + int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new, + int const* const seqused_q, int const* const seqused_k, int const* const ptr_leftpad_k, + int const* const seqlens_rotary + ) + : leftpad_k(ptr_leftpad_k ? ptr_leftpad_k[bidb] : 0) + , offset_q(!Varlen || cu_seqlens_q == nullptr ? 0 : cu_seqlens_q[bidb]) + , offset_k(!Varlen ? 0 : (cu_seqlens_k ? cu_seqlens_k[bidb] : 0) + leftpad_k) + , offset_k_new(!AppendKV || cu_seqlens_k_new == nullptr ? 0 : cu_seqlens_k_new[bidb]) + , seqlen_q(!Varlen + ? seqlen_q_static + : (seqused_q ? seqused_q[bidb] : (cu_seqlens_q ? cu_seqlens_q[bidb + 1] - cu_seqlens_q[bidb] : seqlen_q_static))) + , seqlen_k_og(!Varlen + ? seqlen_k_static + : (seqused_k ? seqused_k[bidb] : (cu_seqlens_k ? cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb] : seqlen_k_static)) - leftpad_k) + , seqlen_k_new(!AppendKV + ? 0 + : (cu_seqlens_k_new ? cu_seqlens_k_new[bidb + 1] - cu_seqlens_k_new[bidb] : shape_K_new_0)) + , seqlen_k(!AppendKV ? seqlen_k_og : seqlen_k_og + seqlen_k_new) + , seqlen_rotary(!AppendKV || !seqlens_rotary ? seqlen_k_og + leftpad_k : seqlens_rotary[bidb]) + { + } + +}; + +} // namespace flash diff --git a/flash-attn/sm90_pipeline_no_cluster.hpp b/flash-attn/sm90_pipeline_no_cluster.hpp new file mode 100644 index 0000000000000000000000000000000000000000..65a3d1554b362260207f295630433a95b7e5ac39 --- /dev/null +++ b/flash-attn/sm90_pipeline_no_cluster.hpp @@ -0,0 +1,99 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +namespace cutlass { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// As of Cutlass v3.6.0, if size(ClusterShape) == 1, PipelineTmaAsync has all threads +// signaling the barrier during consumer_release. This causes a perf regression in FA3 +// forward pass (especially hdim 128 causal). We instead reimplement the version of +// PipelineTmaAsync before v3.6.0 where only 1 out of 128 threads signals the barrier. 
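As a quick illustration of the "1 out of 128 threads" release described above, the sketch below (not part of the patch) checks on the host that the predicate `threadIdx.x % NumThreadsPerWarpGroup == 0` used by `consumer_release` yields exactly one barrier arrival per consumer warpgroup, matching the `num_consumers / NumThreadsPerWarpGroup` arrival count that `init_barriers` programs into the empty barrier. The consumer-thread count chosen here is an illustrative assumption.

```cpp
// Host-side sketch: one signaling thread per 128-thread consumer warpgroup.
#include <cassert>
#include <cstdio>

int main() {
    constexpr int kNumThreadsPerWarpGroup = 128;   // cutlass::NumThreadsPerWarpGroup
    constexpr int kNumConsumers = 256;             // e.g. two MMA warpgroups (assumption)

    int signaling_threads = 0;
    for (int tid = 0; tid < kNumConsumers; ++tid) {
        // Same predicate as consumer_release's is_signaling_thread argument.
        bool is_signaling_thread = (tid % kNumThreadsPerWarpGroup) == 0;
        signaling_threads += is_signaling_thread ? 1 : 0;
    }
    int expected_arrivals = kNumConsumers / kNumThreadsPerWarpGroup;
    assert(signaling_threads == expected_arrivals);
    std::printf("%d arrivals per stage for %d consumer threads\n",
                signaling_threads, kNumConsumers);
    return 0;
}
```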
+// +// Assumption: params.num_consumers % NumThreadsPerWarpGroup == 0 +template > +class PipelineTmaAsyncNoCluster: public Base { +public: + using FullBarrier = typename Base::FullBarrier; + using EmptyBarrier = typename Base::EmptyBarrier; + static constexpr uint32_t Stages = Stages_; + using PipelineState = typename Base::PipelineState; + + using SharedStorage = typename Base::SharedStorage; + using ThreadCategory = typename Base::ThreadCategory; + using Params = typename Base::Params; + + static + CUTLASS_DEVICE + void + init_barriers(SharedStorage& storage, Params params) { + int warp_idx = canonical_warp_idx_sync(); + bool is_initializing_warp = (warp_idx == 0); + if (is_initializing_warp) { + // Barrier FULL and EMPTY init + constexpr int producer_arv_cnt = 1; + uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; + uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster; + + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); + } + cutlass::arch::fence_barrier_init(); + } + + template + CUTLASS_DEVICE + PipelineTmaAsyncNoCluster(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) + : Base(storage, params, make_shape(_1{}, _1{}, _1{}) /*cluster_shape*/, cute::false_type{} /*init_barriers*/, cute::false_type{} /*init_masks*/) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) { + + int warp_idx = canonical_warp_idx_sync(); + int lane_predicate = cute::elect_one_sync(); + + static_assert(cute::is_same_v || cute::is_same_v); + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params); + } + + } + + // Constructor + template + CUTLASS_DEVICE + PipelineTmaAsyncNoCluster(SharedStorage& storage, Params params, ClusterShape cluster_shape) + : PipelineTmaAsyncNoCluster(storage, params, cluster_shape, cute::true_type{}, cute::true_type{}) { } + + template + CUTLASS_DEVICE + PipelineTmaAsyncNoCluster(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}) + : PipelineTmaAsyncNoCluster(storage, params, cluster_shape, InitBarriers{}, cute::true_type{}) { } + + CUTLASS_DEVICE + void consumer_release(PipelineState state) { + consumer_release(state.index()); + } + +private: + EmptyBarrier* const empty_barrier_ptr_ = nullptr; + + // Consumer signalling Producer of completion + // Ensures all blocks in the Same Row and Column get notifed. + CUTLASS_DEVICE + void consumer_release(uint32_t stage, uint32_t skip = false) { + empty_barrier_ptr_[stage].arrive(0 /*dst_blockid_*/, uint32_t(threadIdx.x % cutlass::NumThreadsPerWarpGroup == 0) & (!skip) /*is_signaling_thread*/); + } + +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cutlass diff --git a/flash-attn/softmax.h b/flash-attn/softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..8fcdb6bd07083649fcfc1c78202af194aa237f99 --- /dev/null +++ b/flash-attn/softmax.h @@ -0,0 +1,170 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include + +#include + +#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ni++) { + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init && ni == 0 ? tensor(mi, ni) : op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + #pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); + if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); } +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { + // For FP8, we can subtract max by 8.0 so that the value after exp2 is in the range of [0, 256]. + // This lets us use more of the FP8 range (instead of just [0, 1]) to reduce underflow. + static constexpr float max_offset = float(Max_offset); // We can only template on int, not float + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = Check_inf + ? (max(mi) == -INFINITY ? 0.f : (!Scale_max ? max(mi) : max(mi) * scale) - max_offset) + : (!Scale_max ? max(mi) : max(mi) * scale) - max_offset; + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)). This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. 
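The identity behind the `exp2f` form used in `scale_apply_exp2` can be checked on the host. The sketch below (not part of the patch) verifies that exp(scale * (x - m)) equals exp2(x * scale_log2 - m * scale_log2), and shows the online-softmax rescale factor that the `Softmax` struct applies to the running row sum when a larger row max appears. All numeric values are arbitrary.

```cpp
// Host-side check of the exp2/log2(e) trick and the online-softmax rescale.
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
    const float kLog2e = 1.4426950408889634f;        // log2(e)
    const float softmax_scale = 1.0f;                // assumption, for simplicity
    const float scale_log2 = softmax_scale * kLog2e; // corresponds to softmax_scale_log2

    // exp(scale * (x - m)) computed the way the kernel does it: one FMA feeding exp2.
    const float x = 3.7f, m = 5.2f;
    float ref  = std::exp(softmax_scale * (x - m));
    float fast = std::exp2(x * scale_log2 - m * scale_log2);
    assert(std::fabs(ref - fast) < 1e-6f);

    // Online softmax: when the row max grows from m_prev to m_cur, the running
    // sum (and accumulator) is rescaled by exp2((m_prev - m_cur) * scale_log2).
    float m_prev = 5.2f, m_cur = 6.0f, row_sum = 2.5f;
    float scores_scale = std::exp2((m_prev - m_cur) * scale_log2);
    assert(std::fabs(row_sum * scores_scale - row_sum * std::exp(m_prev - m_cur)) < 1e-6f);

    std::printf("ref=%f fast=%f scores_scale=%f\n", ref, fast, scores_scale);
    return 0;
}
```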
+ tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax { + + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; + float const softmax_scale_log2; + + CUTLASS_DEVICE Softmax(float const softmax_scale_log2_) : softmax_scale_log2(softmax_scale_log2_) {}; + + template + __forceinline__ __device__ TensorT max_get_scale(Tensor0 &acc_s) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(CUTE_STATIC_V(size<0>(scores)) == kNRows); + TensorT scores_scale; + if constexpr (Is_first) { + flash::template reduce_max(scores, row_max); + cute::fill(scores_scale, 1.f); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + flash::template reduce_max(scores, row_max); + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + row_sum(mi) *= scores_scale(mi); + } + } + return scores_scale; + }; + + template + __forceinline__ __device__ void online_softmax(Tensor0 &acc_s) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(CUTE_STATIC_V(size<0>(scores)) == kNRows); + flash::template scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + flash::reduce_sum(scores, row_sum); + }; + + __forceinline__ __device__ TensorT finalize(float const final_scale=1.f) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT scores_scale; + #pragma unroll + for (int mi = 0; mi < size(row_sum); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 0.f : 1.f / sum; + scores_scale(mi) = inv_sum * final_scale; + // For FP8, we might have scaled the output of exp by 2**8 so we need to divide sum by that amount. + if constexpr (Max_offset != 0) { + static constexpr float sum_scale = 1.f / float(1 << Max_offset); + sum *= sum_scale; + } + row_sum(mi) = (sum == 0.f || sum != sum) ? 
-INFINITY : row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum); + } + return scores_scale; + }; + + template + __forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) { + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(CUTE_STATIC_V(size<0>(acc_o_rowcol)) == kNRows); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale(mi); } + } + }; + +}; + +} // namespace flash diff --git a/flash-attn/static_switch.h b/flash-attn/static_switch.h new file mode 100644 index 0000000000000000000000000000000000000000..5e13b5f93a8604f3d6bb24fc57b6c3bbc0b11509 --- /dev/null +++ b/flash-attn/static_switch.h @@ -0,0 +1,181 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +// + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#ifdef FLASHATTENTION_DISABLE_LOCAL + #define CAUSAL_LOCAL_SWITCH(CAUSAL_COND, LOCAL_COND, CAUSAL_CONST_NAME, LOCAL_CONST_NAME, ...) \ + [&] { \ + constexpr static bool LOCAL_CONST_NAME = false; \ + if (CAUSAL_COND) { \ + constexpr static bool CAUSAL_CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CAUSAL_CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() +#else + #define CAUSAL_LOCAL_SWITCH(CAUSAL_COND, LOCAL_COND, CAUSAL_CONST_NAME, LOCAL_CONST_NAME, ...) \ + [&] { \ + if (CAUSAL_COND) { \ + constexpr static bool CAUSAL_CONST_NAME = true; \ + constexpr static bool LOCAL_CONST_NAME = false; \ + return __VA_ARGS__(); \ + } else if (LOCAL_COND) { \ + constexpr static bool CAUSAL_CONST_NAME = false; \ + constexpr static bool LOCAL_CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CAUSAL_CONST_NAME = false; \ + constexpr static bool LOCAL_CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() +#endif + +#ifdef FLASHATTENTION_DISABLE_SOFTCAP + #define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define SOFTCAP_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_PAGEDKV + #define PAGEDKV_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define PAGEDKV_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_SPLIT + #define SPLIT_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define SPLIT_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_APPENDKV + #define APPENDKV_SWITCH(COND, CONST_NAME, ...) 
\ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define APPENDKV_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_PACKGQA + #define PACKGQA_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define PACKGQA_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_VARLEN + #define VARLEN_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define VARLEN_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_CLUSTER + #define CLUSTER_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define CLUSTER_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_SM8x + #define ARCH_SWITCH(ARCH, ARCH_NAME, ...) \ + [&] { \ + constexpr static int ARCH_NAME = 90; \ + return __VA_ARGS__(); \ + }() +#else + #define ARCH_SWITCH(ARCH, ARCH_NAME, ...) \ + [&] { \ + if (ARCH == 86 || ARCH == 89) { \ + constexpr static int ARCH_NAME = 86; \ + return __VA_ARGS__(); \ + } else if (ARCH < 90) { \ + constexpr static int ARCH_NAME = 80; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static int ARCH_NAME = 90; \ + return __VA_ARGS__(); \ + } \ + }() +#endif + +#ifndef FLASHATTENTION_ENABLE_VCOLMAJOR + #define VCOLMAJOR_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define VCOLMAJOR_SWITCH BOOL_SWITCH +#endif + +#define HEADDIM_SWITCH(HEADDIM, ...) \ + [&] { \ + if (HEADDIM == 64) { \ + constexpr static int kHeadSize = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 96) { \ + constexpr static int kHeadSize = 96; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 128) { \ + constexpr static int kHeadSize = 128; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 96) { \ + constexpr static int kHeadSize = 96; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 256) { \ + constexpr static int kHeadSize = 256; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/flash-attn/tile_scheduler.hpp b/flash-attn/tile_scheduler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1f90f66adc2177c872abdf7fc589e452b5e5586f --- /dev/null +++ b/flash-attn/tile_scheduler.hpp @@ -0,0 +1,709 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
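For readers unfamiliar with the dispatch style used by `static_switch.h` above, the standalone sketch below (not part of the patch) restates the `BOOL_SWITCH` pattern locally and shows how a runtime flag becomes a `constexpr` bool that can parameterize a template; `run_kernel` is a hypothetical stand-in for an actual kernel launch.

```cpp
// Standalone restatement of the BOOL_SWITCH dispatch pattern.
#include <cstdio>

#define BOOL_SWITCH(COND, CONST_NAME, ...)                \
    [&] {                                                 \
        if (COND) {                                       \
            constexpr static bool CONST_NAME = true;      \
            return __VA_ARGS__();                         \
        } else {                                          \
            constexpr static bool CONST_NAME = false;     \
            return __VA_ARGS__();                         \
        }                                                 \
    }()

template <bool Is_causal>
int run_kernel(int seqlen) {
    // Stand-in for launching a kernel template specialized on Is_causal.
    return Is_causal ? seqlen / 2 : seqlen;
}

int main(int argc, char**) {
    bool is_causal = (argc > 1);   // runtime value
    int work = BOOL_SWITCH(is_causal, kIsCausal, [&] {
        return run_kernel<kIsCausal>(1024);
    });
    std::printf("work = %d\n", work);
    return 0;
}
```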
+ ******************************************************************************/ + +#pragma once + +#include "cutlass/fast_math.h" +#include "cutlass/arch/barrier.h" + +#include "named_barrier.hpp" +#include "utils.h" + +namespace flash { + +/////////////////////////////////////////////////////////////////////////////// + +// Host side kernel arguments +struct TileSchedulerArguments { + // num_head is num_head_q if not PackGQA, else num_head_k + int const num_blocks, num_head, num_batch, num_splits; + int const qhead_per_khead; + int const seqlen; // Only used if Varlen and cu_seqlens == nullptr and seqused == nullptr + int const seqlen_k, headdim, headdim_v, element_size; // Used to calculate L2 swizzling + int* const tile_count_semaphore = nullptr; + int const* const cu_seqlens = nullptr; + int const* const seqused = nullptr; + // int const* const num_m_blocks_ptr = nullptr; + int const* const num_splits_dynamic_ptr = nullptr; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template +class SingleTileScheduler { + +public: + + using SharedStorage = int; + + // Device side kernel params + struct Params { + int const num_blocks, num_head, num_batch, num_splits; + int const qhead_per_khead; + int const seqlen; + cutlass::FastDivmod nsplits_divmod; + int const* const cu_seqlens; + int const* const seqused; + int const* const num_splits_dynamic_ptr = nullptr; + }; + + static Params + to_underlying_arguments(TileSchedulerArguments const& args) { + assert(!Split || !Varlen || args.num_splits_dynamic_ptr != nullptr); + assert(!Split || !Varlen || args.num_splits < (1 << 16)); // We use the top 16 bits to store num_splits + return {args.num_blocks, args.num_head, args.num_batch, !Split ? 1 : args.num_splits, + args.qhead_per_khead, args.seqlen, + cutlass::FastDivmod(!Split ? 1 : args.num_splits), + !Varlen ? nullptr : args.cu_seqlens, !Varlen ? nullptr : args.seqused, + args.num_splits_dynamic_ptr}; + } + + static dim3 + get_grid_shape(Params const& params, int num_sm) { + return {uint32_t(params.num_blocks), uint32_t((!Split ? 1 : params.num_splits) * params.num_head), uint32_t(params.num_batch)}; + } + + struct WorkTileInfo { + int block_idx = 0; + int bidh = 0; + int bidb = 0; + int split_idx = 0; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return bidb >= 0; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + return {block_idx, bidh, bidb, !Split ? 0 : split_idx}; + } + + }; + + CUTLASS_DEVICE + SingleTileScheduler(SharedStorage* const smem_scheduler) { } + + template + CUTLASS_DEVICE + WorkTileInfo + get_initial_work(Params const& params) const { + WorkTileInfo work_info {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), 0}; + if constexpr (Split) { + int split_idx; + work_info.bidh = params.nsplits_divmod.divmod(split_idx, work_info.bidh); + work_info.split_idx = split_idx; + } + bool is_valid_tile = true; + if constexpr (Varlen) { + int seqlen = params.seqused + ? params.seqused[work_info.bidb] + : (params.cu_seqlens ? params.cu_seqlens[work_info.bidb + 1] - params.cu_seqlens[work_info.bidb] : params.seqlen); + if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } + is_valid_tile = work_info.block_idx * kBlock < seqlen; + } + if constexpr (Varlen && Split) { + int num_splits_dynamic = params.num_splits_dynamic_ptr ? 
params.num_splits_dynamic_ptr[work_info.bidb] : params.num_splits; + is_valid_tile &= work_info.split_idx < num_splits_dynamic; + // Use the top 16 bits to store num_splits + work_info.split_idx |= (num_splits_dynamic << 16); + } + work_info.bidb = is_valid_tile ? work_info.bidb : -1; + return work_info; + } + + CUTLASS_DEVICE + void + init_consumer() const {} + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + return {0, 0, -1, 0}; + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +template +class StaticPersistentTileScheduler { + +public: + + using SharedStorage = int; + + // Device side kernel params + struct Params { + int total_blocks; + cutlass::FastDivmod m_block_divmod, head_divmod; + cutlass::FastDivmod nsplits_divmod; + }; + + static Params + to_underlying_arguments(TileSchedulerArguments const& args) { + return {args.num_blocks * args.num_head * args.num_batch * (!Split ? 1 : args.num_splits), + cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head * (!Split ? 1 : args.num_splits)), + cutlass::FastDivmod(!Split ? 1 : args.num_splits)}; + } + + static dim3 + get_grid_shape(Params const& params, int num_sm) { + return {uint32_t(num_sm)}; + } + + struct WorkTileInfo { + int tile_idx; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return tile_idx < params.total_blocks; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + int block, bidh, bidb; + bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(block, tile_idx)); + int split_idx = 0; + if constexpr (Split) { + bidh = params.nsplits_divmod.divmod(split_idx, bidh); + } + return {block, bidh, bidb, split_idx}; + } + + }; + + CUTLASS_DEVICE + StaticPersistentTileScheduler(SharedStorage* const smem_scheduler) {}; + + template + CUTLASS_DEVICE + WorkTileInfo + get_initial_work(Params const& params) const { + return {int(blockIdx.x)}; + } + + CUTLASS_DEVICE + void + init_consumer() const {} + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + return {current_work.tile_idx + int(gridDim.x)}; + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +template +class DynamicPersistentTileScheduler { + + // This scheduler targets the causal (or local) case where each tile takes different + // amount of time. We use longest-processing-time-first scheduling: + // the longest remaining tile is assigned to the first SM that's free. + // SM indicates they are free by incrementing a semaphore. + // However, we have to make sure K & V still fit into L2 cache, so we perform scheduling + // on "sections" of the head & batch dimension, each section consisting of e.g. 8 heads. + // This is the L2 swizzling part. The size of each section is precomputed based on the + // size of K & V and the L2 cache size. + + static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); + static constexpr int NumThreads = WarpSpecialized ? 
NumMmaThreads + NumProducerThreads : NumMmaThreads; + +public: + using SharedStorage = int; + +protected: + SharedStorage* const tile_count_smem; + +public: + + // Device side kernel params + struct Params { + int const total_blocks; + cutlass::FastDivmod const m_block_divmod, head_divmod; + cutlass::FastDivmod const l2_minor_divmod, l2_major_divmod; + cutlass::FastDivmod const l2_minor_residual_divmod; + int const num_hb_quotient; + int* const tile_count_semaphore; + }; + + static Params + to_underlying_arguments(TileSchedulerArguments const& args) { + int const size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size; + int const size_l2 = 32 * 1024 * 1024; // 32 MB for K & V + // Swizzle is the size of each "section". Round swizzle to a power of 2 + // If not PackGQA already, the size of each section can increase by qhead_per_khead + // Need to be careful about the case where only one head will fit + auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); }; + // Seems faster if swizzle if a power of 2 + int const swizzle = (size_l2 < size_one_kv_head ? 1 : (1 << find_log2_floor(size_l2 / size_one_kv_head))) * (PackGQA ? 1 : args.qhead_per_khead); + // If we're in the last section (called residual), we don't want to divide by + // swizzle. Instead we want to divide by the remainder. + int const num_hb_remainder = (args.num_head * args.num_batch) % swizzle; + int const num_split_blocks = args.num_blocks * (!Split ? 1 : args.num_splits); + // printf("num_split_blocks = %d, num_head = %d, num_batch = %d, swizzle = %d, PackGQA = %d, qhead_per_khead = %d, num_hb_remainder = %d\n", num_split_blocks, args.num_head, args.num_batch, swizzle, int(PackGQA), args.qhead_per_khead, num_hb_remainder); + assert(args.tile_count_semaphore != nullptr); + return {num_split_blocks * args.num_head * args.num_batch, + cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head), + cutlass::FastDivmod(swizzle), cutlass::FastDivmod(swizzle * num_split_blocks), + // don't divide by 0 + cutlass::FastDivmod(num_hb_remainder > 0 ? num_hb_remainder : 1), + (args.num_head * args.num_batch) / swizzle, + args.tile_count_semaphore}; + } + + static dim3 + get_grid_shape(Params const& params, int num_sm) { + return {uint32_t(num_sm)}; + } + + struct WorkTileInfo { + int tile_idx; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return tile_idx < params.total_blocks; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + int block, bidh, bidb; + int l2_mod, bidhb, bidhb_residual; + bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx); + // If we're in the last section (called residual), we don't want to divide by + // swizzle. Instead we want to divide by the remainder. 
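The section/L2-swizzle decomposition performed by `get_block_coord` can be mirrored with plain integer division. The host-side sketch below (not part of the patch) decodes a linear `tile_idx` into `(block, bidh, bidb)` the same way and applies the final longest-processing-time-first reversal of the block index. The problem sizes are assumptions, and the `Split` path and the residual (partially filled) section are omitted for brevity.

```cpp
// Host-side mirror of DynamicPersistentTileScheduler's tile decoding
// (non-Split, non-residual case), using plain '/' and '%' instead of FastDivmod.
#include <cstdio>

struct BlockCoord { int block, bidh, bidb; };

BlockCoord decode(int tile_idx, int num_blocks, int num_head, int swizzle) {
    int l2_major = swizzle * num_blocks;          // tiles per "section" of heads/batches
    int bidhb    = tile_idx / l2_major;           // which section
    int l2_mod   = tile_idx % l2_major;
    int block    = l2_mod / swizzle;              // m-block inside the section
    int residual = l2_mod % swizzle;              // head/batch index inside the section
    int hb       = bidhb * swizzle + residual;    // flattened (head, batch)
    BlockCoord c;
    c.bidh  = hb % num_head;
    c.bidb  = hb / num_head;
    c.block = num_blocks - 1 - block;             // LPT: hand out the longest tiles first
    return c;
}

int main() {
    const int num_blocks = 4, num_head = 8, num_batch = 2, swizzle = 8;  // assumptions
    for (int tile_idx = 0; tile_idx < num_blocks * num_head * num_batch; ++tile_idx) {
        BlockCoord c = decode(tile_idx, num_blocks, num_head, swizzle);
        if (tile_idx < 10) {
            std::printf("tile %2d -> block %d head %d batch %d\n",
                        tile_idx, c.block, c.bidh, c.bidb);
        }
    }
    return 0;
}
```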
+ if (bidhb < params.num_hb_quotient) { + block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod); + } else { + block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod); + } + bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual); + int split_idx = 0; + if constexpr (Split) { + split_idx = params.m_block_divmod.divmod(block, block); + } + // Longest-processing-time-first + block = params.m_block_divmod.divisor - 1 - block; + return {block, bidh, bidb, split_idx}; + } + + }; + + CUTLASS_DEVICE + DynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : tile_count_smem(smem_scheduler) {}; + + template + CUTLASS_DEVICE + WorkTileInfo + get_initial_work(Params const& params) const { + return {int(blockIdx.x)}; + } + + CUTLASS_DEVICE + void + init_consumer() const { + if (WarpSpecialized || cutlass::canonical_warp_idx_sync() > 0) { + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty + } + } + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { + if (threadIdx.x % NumProducerThreads == 0) { + current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x); + } + } + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + if constexpr (IsProducerWarp) { + // thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0 + int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/); + flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty + if (threadIdx.x % NumProducerThreads == 0) { + *tile_count_smem = current_work.tile_idx; + } + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull + return {new_tile_idx}; + } else { + flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull + int tile_idx = *tile_count_smem; + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty + return {tile_idx}; + } + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +class SingleTileBwdLPTScheduler { + +public: + + using SharedStorage = int; + + // Device side kernel params + struct Params { + int const total_blocks; + cutlass::FastDivmod const m_block_divmod, head_divmod; + cutlass::FastDivmod const l2_minor_divmod, l2_major_divmod; + cutlass::FastDivmod const l2_minor_residual_divmod; + int const num_hb_quotient; + }; + + static Params + to_underlying_arguments(TileSchedulerArguments const& args) { + // Since it's the bwd pass, seqlen_k get passed to args.seqlen and seqlen_q is passed to args.seqlen_k + int const size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size; + int const size_one_dqaccum_head = args.seqlen_k * args.headdim * sizeof(float); + int const size_one_head = size_one_qdo_head + size_one_dqaccum_head; + int const size_l2 = 40 * 1024 * 1024; // 40 MB for Q, dO, and dQaccum + // Swizzle is the size of each "section". 
Round swizzle to a power of 2 + // Need to be careful about the case where only one head will fit + auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); }; + // Seems faster if swizzle if a power of 2 + int const swizzle = size_l2 < size_one_head ? 1 : (1 << find_log2_floor(size_l2 / size_one_head)); + // If we're in the last section (called residual), we don't want to divide by + // swizzle. Instead we want to divide by the remainder. + int const num_hb_remainder = (args.num_head * args.num_batch) % swizzle; + // printf("num_blocks = %d, num_head = %d, num_batch = %d, size_one_head = %d, ratio = %d, swizzle = %d, num_hb_remainder = %d\n", args.num_blocks, args.num_head, args.num_batch, size_one_head, size_l2 / size_one_head, swizzle, num_hb_remainder); + assert(args.tile_count_semaphore != nullptr); + return {args.num_blocks * args.num_head * args.num_batch, + cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head), + cutlass::FastDivmod(swizzle), cutlass::FastDivmod(swizzle * args.num_blocks), + // don't divide by 0 + cutlass::FastDivmod(num_hb_remainder > 0 ? num_hb_remainder : 1), + (args.num_head * args.num_batch) / swizzle}; + } + + static dim3 + get_grid_shape(Params const& params, int num_sm) { + return {uint32_t(params.total_blocks)}; + } + + struct WorkTileInfo { + int tile_idx; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return tile_idx < params.total_blocks; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + int block, bidh, bidb; + int l2_mod, bidhb, bidhb_residual; + bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx); + // If we're in the last section (called residual), we don't want to divide by + // swizzle. Instead we want to divide by the remainder. + if (bidhb < params.num_hb_quotient) { + block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod); + } else { + block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod); + } + bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual); + return {block, bidh, bidb, 0 /*split_idx*/}; + } + + }; + + CUTLASS_DEVICE + SingleTileBwdLPTScheduler(SharedStorage* const smem_scheduler) { } + + template + CUTLASS_DEVICE + WorkTileInfo + get_initial_work(Params const& params) const { + return {int(blockIdx.x)}; + } + + CUTLASS_DEVICE + void + init_consumer() const {} + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + return {params.total_blocks}; + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +template +class VarlenDynamicPersistentTileScheduler { + + static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); + static constexpr int NumThreads = WarpSpecialized ? 
NumMmaThreads + NumProducerThreads : NumMmaThreads; + +public: + using SharedStorage = int4; + +protected: + SharedStorage* const work_info_smem; + +public: + + // Device side kernel params + struct Params { + int num_head, num_batch; + int const qhead_per_khead; + int const seqlen; + cutlass::FastDivmod head_divmod; + cutlass::FastDivmod nsplits_divmod; + int* const tile_count_semaphore; + int const* const cu_seqlens; + int const* const seqused; + // int* const num_m_blocks_ptr; + int const* const num_splits_dynamic_ptr; + }; + + static Params + to_underlying_arguments(TileSchedulerArguments const& args) { + // If Split, for the purpose of scheduling, we pretend that instead there are + // (args.num_splits * args.num_head) number of heads. + assert(args.tile_count_semaphore != nullptr); + assert(args.num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx + assert(!Split || args.num_splits < (1 << 8)); // We use the top 8 bits to store num_splits + return {args.num_head, args.num_batch, + args.qhead_per_khead, args.seqlen, + cutlass::FastDivmod(args.num_head), + cutlass::FastDivmod(!Split ? 1 : args.num_splits), + args.tile_count_semaphore, args.cu_seqlens, args.seqused, + // args.num_m_blocks_ptr, + args.num_splits_dynamic_ptr}; + } + + static dim3 + get_grid_shape(Params const& params, int num_sm) { + return {uint32_t(num_sm)}; + } + + struct WorkTileInfo { + int tile_idx, block, bidh, bidb; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + // if (blockIdx.x >= 0 && (threadIdx.x == 128 || threadIdx.x == 0)) { printf("blockIdx.x = %d, threadIdx.x = %d, checking valid, bidb = %d, params.num_batch = %d\n", blockIdx.x, threadIdx.x, bidb, params.num_batch); } + return bidb < params.num_batch; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + if constexpr (!Split) { + return {block, bidh, bidb, 0 /*split_idx*/}; + } else { + // the top 8 bits of bidh store num_splits and the next 8 bits store split_idx + // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift + uint32_t bidh_packed = reinterpret_cast(bidh); + uint32_t bidh_actual_u = bidh_packed & 0x0000FFFF; + int bidh_actual = reinterpret_cast(bidh_actual_u); + // Use the top 16 bits of split_idx to store num_splits and the next 16 bits to store split_idx + uint32_t split_idx_u = ((bidh_packed & 0x00FF0000) >> 16) + ((bidh_packed & 0xFF000000) >> 8); + int split_idx = reinterpret_cast(split_idx_u); + // int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh); + // if (threadIdx.x == 128) { + // printf("blockIdx.x = %d, bidb = %d, bidh = %d, bidh_actual = %d, split_idx = %d\n", blockIdx.x, bidb, bidh, bidh_actual, split_idx); + // } + return {block, bidh_actual, bidb, split_idx}; + } + } + }; + + CUTLASS_DEVICE + VarlenDynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : work_info_smem(smem_scheduler) {}; + + CUTLASS_DEVICE + WorkTileInfo + tile_idx_to_work_tile(Params const& params, int next_tile_idx, WorkTileInfo const& current_work) const { + int lane = threadIdx.x % cutlass::NumThreadsPerWarp; + auto get_num_m_blocks = [&] (int bidb_start) { + int batch_idx = lane + bidb_start; + int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead); + if (seqlen > kBlock) { + if (params.seqused) { + seqlen = batch_idx < params.num_batch ? params.seqused[batch_idx] : 0; + } else if (params.cu_seqlens) { + int cur_cu_seqlen = batch_idx <= params.num_batch ? 
params.cu_seqlens[batch_idx] : 0; + int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); + seqlen = next_cu_seqlen - cur_cu_seqlen; + } else { + seqlen = params.seqlen; + } + if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } + } + return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 + ? cute::ceil_div(seqlen, kBlock) : 0; + // ? params.num_m_blocks_ptr[batch_idx] : 0; + }; + + auto get_num_splits = [&] (int bidb_start) { + int batch_idx = lane + bidb_start; + return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 + ? (!Split ? 1 : (params.num_splits_dynamic_ptr + ? params.num_splits_dynamic_ptr[batch_idx] + : params.nsplits_divmod.divisor)) + : 0; + }; + + int num_m_blocks = get_num_m_blocks(current_work.bidb); // Different for each lane + int num_splits = get_num_splits(current_work.bidb); + int num_split_m_blocks = !Split ? num_m_blocks : num_m_blocks * num_splits; + // Cumulative number of blocks for the next 31 batches + int num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks); + // Total number of blocks for the next 31 batches + int m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); + // Only the lower 16 bits are the actual bidh + int current_bidh = !Split ? current_work.bidh : (current_work.bidh & 0x0000FFFF); + int group_end_tile = current_work.tile_idx - current_work.block - current_bidh * __shfl_sync(0xffffffff, num_split_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes + if constexpr (Split) { + int current_split_idx = (current_work.bidh & 0x00FF0000) >> 16; + group_end_tile -= current_split_idx * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/); + } + int bidb = current_work.bidb; + // if (blockIdx.x <= 9 && threadIdx.x == 0) { + // printf("Before while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, cur tile_idx = %d, cur block = %d, cur bidh = %d, num_split_m_blocks = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, current_work.bidb, num_m_blocks, next_tile_idx, current_work.tile_idx, current_work.block, current_bidh, num_split_m_blocks, group_end_tile, m_blocks_in_group); + // } + // if (threadIdx.x == 0 && blockIdx.x == 0) { printf("tile_idx = %d, group_end_tile = %d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d\n", current_work.tile_idx, group_end_tile, num_m_blocks_cumulative, m_blocks_in_group); } + while (group_end_tile <= next_tile_idx) { + bidb += cutlass::NumThreadsPerWarp - 1; + if (bidb >= params.num_batch) { + // if (blockIdx.x <= 9 && threadIdx.x == 0) { + // printf("Returning early, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group); + // } + return {next_tile_idx, 0, 0, params.num_batch}; + } + num_m_blocks = get_num_m_blocks(bidb); + num_splits = get_num_splits(bidb); + num_split_m_blocks = !Split ? 
num_m_blocks : num_m_blocks * num_splits; + num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks); + m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); + group_end_tile += m_blocks_in_group * params.num_head; + // if (blockIdx.x <= 9 && threadIdx.x == 0) { + // printf("Bottom of while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group); + // } + } + int group_start_tile = group_end_tile - m_blocks_in_group * params.num_head; + // The next problem to process is the first one that does not have ending tile position + // that is greater than or equal to tile index. + int batch_idx_in_group = __popc(__ballot_sync(0xffffffff, group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx)); + // if (threadIdx.x == 31 || threadIdx.x == 0) { printf("blockIdx.x = %d, tidx %d, group_start_tile = %d, num_m_blocks_cumulative = %d, num_head = %d, next_tile_idx = %d, ballot = %x, batch_idx_in_group = %d\n", blockIdx.x, threadIdx.x, group_start_tile, num_m_blocks_cumulative, params.num_head, next_tile_idx, tmp, batch_idx_in_group); } + bidb += batch_idx_in_group; + num_m_blocks = __shfl_sync(0xffffffff, num_m_blocks, batch_idx_in_group); + if constexpr (Split) { num_splits = __shfl_sync(0xffffffff, num_splits, batch_idx_in_group); } + int mh_block = next_tile_idx - group_start_tile - (batch_idx_in_group == 0 ? 0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head; + int bidh = mh_block / num_m_blocks; + int block = mh_block - bidh * num_m_blocks; + if constexpr (Split) { + int bidh_actual = bidh / num_splits; + int split_idx = bidh - bidh_actual * num_splits; + // TODO: idk why this gives wrong answer nondeterministically + // int bidh_actual, split_idx; + // split_idx = params.head_divmod.divmod(bidh_actual, bidh); + // Use the top 8 bits to store num_splits and the next 8 bits to store split_idx + // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift + uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); + // if (threadIdx.x == 0) { + // printf("blockIdx.x = %d, group_start_tiled = %d, bidb = %d, batch_idx_in_group = %d, mh_block = %d, num_m_blocks = %d, bidh = %d, bidh_actual = %d, split_idx = %d, num_splits = %d, bidh_packed = %d\n", blockIdx.x, group_start_tile, bidb, batch_idx_in_group, mh_block, num_m_blocks, bidh, bidh_actual, split_idx, num_splits, bidh_packed); + // } + bidh = reinterpret_cast(bidh_packed); + } + // if (blockIdx.x <= 9 && threadIdx.x == 0) { + // printf("Before returning, blockIdx.x = %d, threadIdx.x = %d, group_start_tile = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, group_start_tile, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); + // } + return {next_tile_idx, block, bidh, bidb}; + } + + template + CUTLASS_DEVICE + WorkTileInfo + get_initial_work(Params const& params) const { + if constexpr (IsProducerWarp) { + WorkTileInfo work_info = tile_idx_to_work_tile(params, int(blockIdx.x), {0, 0, 0, 0}); + if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { + 
*work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb); + } + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull + return work_info; + } else { + return get_next_work(params, {0, 0, 0, 0}); + } + } + + CUTLASS_DEVICE + void + init_consumer() const { + // Don't arrive at the TileCountSmemEmpty barrier here, because get_initial_work will do that + } + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { + if (threadIdx.x % NumProducerThreads == 0) { + current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x); + } + } + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + if constexpr (IsProducerWarp) { + // thread 0 has the next tile_idx, just need to broadcast to the rest of warp 0 + int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/); + WorkTileInfo work_info = {__shfl_sync(0xffffffff, current_work.tile_idx, 1 /*lane*/), current_work.block, current_work.bidh, current_work.bidb}; + work_info = tile_idx_to_work_tile(params, new_tile_idx, work_info); + flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty + if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { + *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb); + } + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull + return work_info; + } else { + flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull + int4 work_info = *work_info_smem; + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty + return WorkTileInfo{work_info.x, work_info.y, work_info.z, work_info.w}; + } + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // flash diff --git a/flash-attn/tile_size.h b/flash-attn/tile_size.h new file mode 100644 index 0000000000000000000000000000000000000000..e6cb31515c731dca3f3ea3579ef9ce9a80ffcfdf --- /dev/null +++ b/flash-attn/tile_size.h @@ -0,0 +1,78 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +// Return {kBlockM, kBlockN, MmaPV_is_RS, IntraWGOverlap} +constexpr std::tuple tile_size_fwd_sm90( + int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2, + bool v_colmajor=false, bool paged_kv_non_TMA=false, bool softcap=false) { + if (element_size == 2) { + if (headdim <= 64) { + // return {same_hdim ? 192 : 64, same_hdim ? 
128 : 64, same_hdim, same_hdim}; + // With this workaround in Cutlass 3.8, tile size 192 x 128 got slower for non-causal, idk why + // https://github.com/NVIDIA/cutlass/blob/833f6990e031b48b4cd2fcf55e0849c51ef6bac2/include/cute/container/tuple.hpp#L131 + if (headdim_v == 512) { + return {64, 64, false, false}; + } else if (headdim_v == 256) { + return {128, 96, true, false}; + } else { + // Switch to tile size 192 x 192 for now + bool const use_blockN_128 = is_causal || is_local; + return {192, use_blockN_128 ? 128 : 192, use_blockN_128, true}; + } + // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen + // return {192, is_causal || is_local ? 192 : 176, true, false}; + } else if (headdim <= 96) { + return {192, is_local || paged_kv_non_TMA ? 128 : 144, false, true}; + } else if (headdim <= 128) { + return {128, is_causal || is_local || paged_kv_non_TMA ? 128 : 176, true, true}; + // {128, 192, false, false} and {192, 128, false, true} are quite good too + // 128 x 192 hits the limit of smem if MmaPV_is_RS, 128 x 144 hits the limit if !MmaPV_is_RS + } else if (headdim <= 192) { + return {128, paged_kv_non_TMA || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true}; // 128 x 112 hits the limit of smem + } else { + return {128, is_local ? 64 : 80, true, true}; // 128 x 80 hits the limit of smem + } + } else { + if (headdim <= 64) { + return {192, 160, true, true}; + } else if (headdim <= 96) { + return {192, 128, true, true}; + } else if (headdim <= 128) { + return {128, paged_kv_non_TMA ? 160 : (v_colmajor || (softcap && is_local) ? 192 : 224), true, true}; + } else if (headdim <= 192) { + return {128, (paged_kv_non_TMA || softcap) && is_local ? 128 : 160, true, true}; + } else { + return {128, is_local ? 64 : 128, true, !paged_kv_non_TMA}; // PagedKV uses more registers so we disabled IntraWGOverlap + } + } +} + +// Return {kBlockM, kBlockN, kNWarps, kStages, Q_in_regs} +constexpr std::tuple tile_size_fwd_sm8x( + bool sm86_or_89, int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2, + bool paged_kv=false, bool varlen_and_split=false, + bool softcap=false, bool append_kv=false) { + if (element_size == 2) { + if (headdim <= 64) { + return {128, varlen_and_split ? 80 : (is_local ? 96 : 112), 4, 1, false}; + } else if (headdim <= 96) { + return {128, varlen_and_split || is_local ? 48 : 64, 4, 1, false}; + } else if (headdim <= 128) { + bool const use_8_warps = sm86_or_89 | varlen_and_split; + return {128, use_8_warps ? (varlen_and_split ? (is_local ? 96 : 112) : (is_local ? 96 : 128)) : (is_local ? 48 : 64), use_8_warps ? 8 : 4, 1, use_8_warps}; + } else if (headdim <= 192) { + bool const kBlockN_64 = append_kv || is_local || varlen_and_split || paged_kv; + return {128, kBlockN_64 ? 64 : 96, 8, sm86_or_89 ? 1 : 2, !kBlockN_64}; + } else { + return {128, sm86_or_89 ? (append_kv ? 32 : (varlen_and_split || is_local ? 48 : 64)) : (append_kv ? 48 : (varlen_and_split || is_local ? 64 : 96)), 8, 1, sm86_or_89 && !append_kv}; + } + } else { + // Placeholder for now + return {128, 64, 8, 2, false}; + } +} diff --git a/flash-attn/utils.h b/flash-attn/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..a568e3075aa3071c15a866bd4074418824fb159b --- /dev/null +++ b/flash-attn/utils.h @@ -0,0 +1,682 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
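The tile-size heuristics above are `constexpr` and return `{kBlockM, kBlockN, ...}` tuples. The sketch below (not part of the patch) shows one way such a tuple can be consumed at compile time, feeding its elements into kernel-trait template parameters; `tile_size_fwd_example` is a deliberately simplified stand-in, not the real table.

```cpp
// Consuming a constexpr tile-size tuple at compile time (C++17).
#include <cstdio>
#include <tuple>

constexpr std::tuple<int, int, bool, bool> tile_size_fwd_example(int headdim, bool is_causal) {
    // Simplified stand-in for the real headdim-dependent heuristic.
    return headdim <= 64 ? std::tuple{192, is_causal ? 128 : 192, true, true}
                         : std::tuple{128, 128, true, true};
}

template <int kBlockM, int kBlockN, bool MmaPV_is_RS, bool IntraWGOverlap>
struct FwdKernelTraits {
    static constexpr int kTileM = kBlockM;
    static constexpr int kTileN = kBlockN;
};

int main() {
    constexpr auto tile = tile_size_fwd_example(/*headdim=*/128, /*is_causal=*/true);
    using Traits = FwdKernelTraits<std::get<0>(tile), std::get<1>(tile),
                                   std::get<2>(tile), std::get<3>(tile)>;
    std::printf("kBlockM=%d kBlockN=%d\n", Traits::kTileM, Traits::kTileN);
    return 0;
}
```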
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include + +#include +#include +#include +#include + +#include "cuda_check.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// A wrapper for the kernel that is used to guard against compilation on +// architectures that will never use the kernel. The purpose of this is to +// reduce the size of the compiled binary. +// Adapted from https://github.com/vllm-project/vllm/blob/4d29e91be84d27ca313d657eee92c067439a4c23/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh#L55 +template +struct enable_sm90_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + Kernel::operator()(std::forward(args)...); +#endif + } +}; + +template +struct enable_sm80_to_sm89 : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ <= 890) + Kernel::operator()(std::forward(args)...); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { +// This is slightly faster +__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ __forceinline__ T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Allreduce<2> { +template +static __device__ __forceinline__ T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTLASS_HOST_DEVICE +int div_floor(cutlass::FastDivmod const& divmod, int dividend) { + // Take care of the negative case: https://stackoverflow.com/questions/39304681/division-with-negative-dividend-but-rounded-towards-negative-infinity + // Maybe the compiler will turn the -1 - * into bit negation operation, I haven't checked. + return dividend >= 0 ? 
divmod.divide(dividend) : -1 - divmod.divide(-1 - dividend); +} + +CUTLASS_HOST_DEVICE +int round_down(cutlass::FastDivmod const& divmod, int dividend) { + return div_floor(divmod, dividend) * divmod.divisor; +} + +CUTLASS_HOST_DEVICE +int round_up(cutlass::FastDivmod const& divmod, int dividend) { + return div_floor(divmod, dividend - 1) * divmod.divisor + divmod.divisor; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) +template +CUTLASS_DEVICE auto convert_layout_acc_rowcol(Layout0 acc_layout) { + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = acc_layout; + if constexpr (!Transposed) { + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l))); + } else { + return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); + } + + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + if constexpr (!Transposed) { + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + } else { + return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. +// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) +// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N)) +template +CUTLASS_DEVICE auto convert_layout_acc_Aregs(Layout0 acc_layout) { + using X = Underscore; + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + static_assert(decltype(rank(get<0>(acc_layout)))::value == 3); + if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) { + auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{}); // ((2, N / 16)) + return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); + } else { + static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1); + static_assert(decltype(stride<0, 0>(acc_layout))::value == 1); + static_assert(decltype(stride<0, 1>(acc_layout))::value == 2); + auto l = logical_divide(get<0, 2>(acc_layout), Tile>>{}); // (((2, 2), N / 32)) + // This combines the first two modes (<0, 0> and <0, 1>) into one mode. + // Will require register shuffling later to be correct. 
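As a quick sanity check on the rounding helpers above (div_floor, round_down, round_up): they build floor division toward negative infinity on top of cutlass::FastDivmod, and the "-1 - divide(-1 - dividend)" trick ensures divide() only ever sees a non-negative argument. A plain-integer Python sketch of the same arithmetic (argument order simplified, no FastDivmod):

    def div_floor(dividend, divisor):
        # Python's // already rounds toward negative infinity, which is exactly
        # what the negative-dividend branch above computes.
        return dividend // divisor

    def round_down(dividend, divisor):
        return div_floor(dividend, divisor) * divisor

    def round_up(dividend, divisor):
        return div_floor(dividend - 1, divisor) * divisor + divisor

    assert div_floor(-3, 4) == -1 and round_down(-3, 4) == -4
    assert round_up(5, 4) == 8 and round_up(4, 4) == 4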
+ return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)), + get<1>(acc_layout), + coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); // ((4, 2, 2), MMA_M, N / 32 * MMA_N) + // This combination is right but doesn't work with register shuffling. + // return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)), + // get<1>(acc_layout), + // coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); + } + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + if constexpr (mma_shape_K == 8) { + return acc_layout; + } else { + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE auto convert_type_unsafe(Tensor const &tensor) { + using From_type = typename Engine::value_type; + static constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast *>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); + // Unsafe because we're returning a tensor with memory allocated on the stack. If the compiler does not + // inline this function, then the memory might not be valid. +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void convert_type_out(Tensor const &tensor, Tensor &out) { + // Somehow if we allocate out inside this function and return it, e2e is slower and the output can be wrong. + using From_type = typename Engine::value_type; + using To_type = typename EngineOut::value_type; + static constexpr int FragmentSize = std::max(sizeof(From_type) / sizeof(To_type), sizeof(To_type) / sizeof(From_type)); + static_assert(CUTE_STATIC_V(size(tensor)) % FragmentSize == 0, "Fragment size does not vectorize properly"); + Tensor frag = recast const>(tensor); + Tensor out_frg = recast>(out); + static_assert(size(frag) == size(out_frg)); + cutlass::NumericArrayConverter convert_op; + #pragma unroll + for (int i = 0; i < size(frag); ++i) { out_frg[i] = convert_op(frag[i]); } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. 
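The SM80 m16n8k16 branch of convert_layout_acc_Aregs above is easier to see with a concrete index model: the MMA_N mode is split into consecutive pairs, and the within-pair index joins the existing MMA=4 mode so the accumulator fragment can be reused as operand A. A rough NumPy model of that index regrouping (shapes chosen arbitrarily for illustration, not tied to any real MMA size):

    import numpy as np

    # (MMA=4, MMA_M, MMA_N) accumulator, here with MMA_M=3 and MMA_N=6.
    acc = np.arange(4 * 3 * 6).reshape(4, 3, 6)
    # Split MMA_N into (MMA_N/2, 2) and move the "2" next to the MMA=4 mode,
    # giving ((4, 2), MMA_M, MMA_N/2) -- the grouping the code builds with cute.
    a_regs = acc.reshape(4, 3, 3, 2).transpose(0, 3, 1, 2)
    assert a_regs.shape == (4, 2, 3, 3)
    assert a_regs[1, 0, 2, 1] == acc[1, 2, 2]   # n = 2*1 + 0
    assert a_regs[1, 1, 2, 1] == acc[1, 2, 3]   # n = 2*1 + 1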
+// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE +void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE +auto mma_partition_fragment_AB(Mma const& mma, Tensor0 const& tensor0) { + if constexpr (A) { + return mma.partition_fragment_A(tensor0); + } else { + return mma.partition_fragment_B(tensor0); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const& tCrB, Tensor2& tCrC) { + if constexpr (M_slice >= 0) { + static constexpr int MMA_M = decltype(size<1>(tCrC))::value; + static_assert(M_slice < MMA_M); + // After logical_divide, C has shape ((2,2,V), (MMA_M, 1), MMA_N) + Tensor tCrC_slice = cute::logical_divide(tCrC, Shape>{})(_, make_coord(Int{}, _), _); + if constexpr (!SwapAB) { + Tensor tCrA_slice = cute::logical_divide(tCrA, Shape>{})(_, make_coord(Int{}, _), _); + gemm(tiled_mma, tCrA_slice, tCrB, tCrC_slice); + } else { + Tensor tCrB_slice = cute::logical_divide(tCrB, Shape>{})(_, make_coord(Int{}, _), _); + gemm(tiled_mma, tCrA, tCrB_slice, tCrC_slice); + } + } else { + constexpr bool Is_RS = !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { + if constexpr (!SwapAB) { + warpgroup_fence_operand(const_cast(tCrA)); + } else { + warpgroup_fence_operand(const_cast(tCrB)); + } + } + warpgroup_fence_operand(tCrC); + warpgroup_arrive(); + if constexpr (zero_init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + static constexpr int kNumKIters = CUTE_STATIC_V(size<2>(tCrA)); + static constexpr int kMaxKIters = 16; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < std::min(kNumKIters, kMaxKIters); ++k_block) { + if constexpr (!SwapAB) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + } else { + cute::gemm(tiled_mma, tCrB(_,_,k_block), tCrA(_,_,k_block), tCrC); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + // In the case of large kNumKIters, the compiler chooses to store the smem addresses + // in registers, causing spills. This loop forces the compiler to recompute the addresses. + if constexpr (kNumKIters > kMaxKIters) { + // This will always be zero, just a way to force the compiler to recompute the smem + // addresses. This results in USEL instructions. There's probably a better way to do this. + int const k_offset = cutlass::canonical_warp_group_idx() < 128 ? 
0 : 1; + CUTLASS_PRAGMA_UNROLL + for (int k_block = kMaxKIters; k_block < kNumKIters; ++k_block) { + if constexpr (!SwapAB) { + cute::gemm(tiled_mma, tCrA(_,_,k_block + k_offset), tCrB(_,_,k_block + k_offset), tCrC); + } else { + cute::gemm(tiled_mma, tCrB(_,_,k_block + k_offset), tCrA(_,_,k_block + k_offset), tCrC); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } + warpgroup_commit_batch(); + if constexpr (wg_wait >= 0) { warpgroup_wait(); } + warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { + if constexpr (!SwapAB) { + warpgroup_fence_operand(const_cast(tCrA)); + } else { + warpgroup_fence_operand(const_cast(tCrB)); + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void gemm_sm80(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, + Tensor4 const& tCsB, TiledMma tiled_mma, + TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, + ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B, Hook fn) { + if constexpr (SwapAB) { + gemm_sm80(acc, tCrB, tCrA, tCsB, tCsA, tiled_mma, smem_tiled_copy_B, smem_tiled_copy_A, smem_thr_copy_B, smem_thr_copy_A, fn); + } else { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } + if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } + #pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } + if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } + } + if constexpr (!std::is_same_v) { + if (i == 0) { fn(); } + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void gemm_rs_sm80(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + #pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void gemm_sm100(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + static constexpr int rA = decltype(rank(tA))::value; + static constexpr int rB = decltype(rank(tB))::value; + static 
constexpr int rC = decltype(rank(tC))::value; + static_assert(rA == 3 && rB == 3 && rC == 3); + + if constexpr (zero_init) { atom.accumulate_ = decltype(atom.accumulate_)::Zero; } + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tA); k_block++) { + cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC); + atom.accumulate_ = decltype(atom.accumulate_)::One; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE constexpr +auto +to_tiled_mma_sm100_ts( + TiledMMA, cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant>, + TAs...>, TMs...>) { + + return TiledMMA>, + TAs...>, TMs...>{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +to_tiled_mma_sm100_ts( + TiledMMA, + TAs...>, TMs...>) { + return TiledMMA, + TAs...>, TMs...>{}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void copy(TiledCopy const &tiled_copy, Tensor const &S, + Tensor &D, Tensor const &identity_MN, + Tensor const &predicate_K, const int max_MN=0) { + // Decay TiledCopy to CopyAtom + auto copy_atom = static_cast(tiled_copy); + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); + auto has_with_bool = cute::is_valid([](auto t)->void_t().with(true))>{}, copy_atom); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + bool predicate_mn = Is_even_MN || get<0>(identity_MN(_0{}, m, _0{})) < max_MN; + if constexpr (Is_even_MN || !Clear_OOB_MN) { + if (Is_even_MN || predicate_mn) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if constexpr (Is_even_K || !Clear_OOB_K) { + if (Is_even_K || predicate_K(k)) { cute::copy(copy_atom, S(_, m, k), D(_, m, k)); } + } else { // Clear_OOB_K == true && Is_even_K == false + // If copy traits can be transformed with a predicate value, do it, otherwise branch here + if constexpr (has_with_bool) { + cute::copy(copy_atom.with(predicate_K(k)), S(_, m, k), D(_, m, k)); + } else { + if (predicate_K(k)) { + cute::copy(copy_atom, S(_, m, k), D(_, m, k)); + } else { + cute::clear(D(_, m, k)); + } + } + } + } + } + } else { // Clear_OOB_MN == true && Is_even_MN == false, also implies Clear_OOB_K == true + if constexpr (!has_with_bool) { + if (predicate_mn) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(copy_atom, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else { + cute::clear(D(_, m, _)); + } + } else { // combine the mn predicate with the k predicate + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + cute::copy(copy_atom.with(predicate_mn && (Is_even_K || predicate_K(k))), S(_, m, k), D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Byte permute and shuffle to match register layout of +// (FP8 downcasted) accumulator of GEMM-I to FP8 operand A of GEMM-II. 
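Before the byte-permute helpers that follow, one note on the flash::copy routine above: it is a boundary-predicated tile copy, and its flag combinations are easier to read in scalar form. A rough Python sketch under simplifying assumptions (the vectorized copy atom, the `with(pred)` fast path, and the compile-time Is_even_* flags are ignored; the function and argument names are illustrative):

    import numpy as np

    def predicated_copy(S, D, max_MN, predicate_K, clear_oob_mn=True, clear_oob_k=True):
        # Rows at or beyond max_MN and columns whose K-predicate is false are
        # either skipped or zero-filled, depending on the Clear_OOB_* flags.
        M, K = S.shape
        for m in range(M):
            if m < max_MN:                     # row in bounds
                for k in range(K):
                    if predicate_K[k]:
                        D[m, k] = S[m, k]      # copy in-bounds element
                    elif clear_oob_k:
                        D[m, k] = 0            # clear out-of-bounds column
            elif clear_oob_mn:
                D[m, :] = 0                    # clear the whole out-of-bounds row

    S = np.ones((4, 8)); D = np.full((4, 8), -1.0)
    predicated_copy(S, D, max_MN=3, predicate_K=np.arange(8) < 6)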
+template +CUTLASS_DEVICE void permute_Aregs_fp8(Fragment &frag) { + // frag has shape ((4, 2, 2), MMA_M, MMA_N), each element is 8 bits + static_assert(decltype(size<0, 0>(frag))::value == 4); + static_assert(decltype(size<0, 1>(frag))::value == 2); + static_assert(decltype(stride<0, 0>(frag))::value == 1); + static_assert(decltype(stride<0, 1>(frag))::value == 4); + static_assert(sizeof(typename Fragment::value_type) == 1); + + int quad_idx = threadIdx.x % 4; + bool lane_03 = quad_idx == 0 || quad_idx == 3; + int selector_upper = lane_03 ? 0x5410 : 0x1054; + int selector_lower = lane_03 ? 0x7632 : 0x3276; + + static constexpr int upper_map[4] = {0, 3, 1, 2}; + // static constexpr int lower_map[4] = {1, 2, 0, 3}; + + Tensor frag_64b = recast(frag); // ((1, 1, 2), MMA_M, MMA_N) + #pragma unroll + for (int i = 0; i < size(frag_64b); ++i) { + uint32_t upper = frag_64b[i].x; + uint32_t lower = frag_64b[i].y; + uint32_t upper0 = lane_03 ? upper : lower; + uint32_t lower0 = lane_03 ? lower : upper; + upper0 = __shfl_sync(uint32_t(-1), upper0, upper_map[quad_idx], 4); + // lower0 = __shfl_sync(uint32_t(-1), lower0, lower_map[quad_idx], 4); + lower0 = __shfl_sync(uint32_t(-1), lower0, upper_map[quad_idx] ^ 1, 4); + frag_64b[i].x = __byte_perm(upper0, lower0, selector_upper); + frag_64b[i].y = __byte_perm(upper0, lower0, selector_lower); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) { + // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits + static_assert(decltype(size<0, 0>(frag))::value == 2); + static_assert(decltype(size<0, 1>(frag))::value == 2); + static_assert(decltype(size<0, 2>(frag))::value % 2 == 0); + static_assert(decltype(stride<0, 0>(frag))::value == 1); + static_assert(sizeof(typename Fragment::value_type) == 4); + Tensor frag_64b = group_modes<1, 3>(recast(frag)); // ((1, 2, N / 8), (MMA_M, MMA_N)) + #pragma unroll + for (int mi = 0; mi < size<1>(frag_64b); ++mi) { + #pragma unroll + for (int i = 0; i < size<0, 2>(frag_64b) / 2; ++i) { + cutlass::swap(frag_64b(make_coord(_0{}, _1{}, 2 * i), mi), frag_64b(make_coord(_0{}, _0{}, 2 * i + 1), mi)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void permute_output_fp8(Fragment &out) { + // out has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits + static_assert(decltype(size<0, 0>(out))::value == 2); + static_assert(decltype(size<0, 1>(out))::value == 2); + static_assert(decltype(size<0, 2>(out))::value % 2 == 0); + static_assert(decltype(stride<0, 0>(out))::value == 1); + static_assert(sizeof(typename Fragment::value_type) == 4); + Tensor frag = group_modes<1, 3>(out); // ((2, 2, N / 8), (MMA_M, MMA_N)) + #pragma unroll + for (int mi = 0; mi < size<1>(frag); ++mi) { + #pragma unroll + for (int j = 0; j < size<0, 1>(frag); ++j) { + #pragma unroll + for (int i = 0; i < size<0, 2>(frag) / 2; ++i) { + cutlass::swap(frag(make_coord(_1{}, j, 2 * i), mi), frag(make_coord(_0{}, j, 2 * i + 1), mi)); + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void permute_output_fp8_Vcolmajor(Fragment &frag) { + // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 16 bits + static_assert(decltype(size<0, 0>(frag))::value == 2); + static_assert(decltype(size<0, 1>(frag))::value == 2); + 
static_assert(decltype(stride<0, 0>(frag))::value == 1); + static_assert(sizeof(typename Fragment::value_type) == 2 || sizeof(typename Fragment::value_type) == 4); + + int quad_idx = threadIdx.x % 4; + bool lane_03 = quad_idx == 0 || quad_idx == 3; + + static constexpr int upper_map[4] = {0, 2, 3, 1}; + // static constexpr int lower_map[4] = {2, 0, 1, 3}; + + // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(frag); } + using type2 = std::conditional_t; + Tensor frag_2 = group_modes<1, 3>(recast(frag)); // ((1, 2, N / 8), (MMA_M, MMA_N)) + // if (blockIdx.x == 0 && threadIdx.x == 128) { print(frag); printf("\n"); print(frag_2); } + #pragma unroll + for (int mi = 0; mi < size<1>(frag_2); ++mi) { + #pragma unroll + for (int j = 0; j < size<0, 1>(frag_2); ++j) { + #pragma unroll + for (int i = 0; i < size<0, 2>(frag_2) / 2; ++i) { + type2 upper = frag_2(make_coord(_0{}, j, 2 * i), mi); + type2 lower = frag_2(make_coord(_0{}, j, 2 * i + 1), mi); + type2 upper0 = lane_03 ? upper : lower; + type2 lower0 = lane_03 ? lower : upper; + upper0 = __shfl_sync(uint32_t(-1), upper0, upper_map[quad_idx], 4); + // lower0 = __shfl_sync(uint32_t(-1), lower0, lower_map[quad_idx], 4); + lower0 = __shfl_sync(uint32_t(-1), lower0, upper_map[quad_idx] ^ 2, 4); + frag_2(make_coord(_0{}, j, 2 * i), mi) = lane_03 ? upper0 : lower0; + frag_2(make_coord(_0{}, j, 2 * i + 1), mi) = lane_03 ? lower0 : upper0; + } + } + } + // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(frag); } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void apply_softcap(Tensor &tensor, float const softcap){ + #pragma unroll + for (int i = 0; i < size(tensor); ++i) { + tensor(i) = cutlass::fast_tanh(tensor(i) * softcap); + } +} + +template +CUTLASS_DEVICE auto calculate_dtanh(Tensor &tensor){ + Tensor out = make_fragment_like(tensor); + #pragma unroll + for (int i = 0; i < size(tensor); ++i) { + out(i) = 1.f - (tensor(i) * tensor(i)); + } + return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_DEVICE T warp_prefix_sum(T val) { + int lane = threadIdx.x % cutlass::NumThreadsPerWarp; + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < cutlass::NumThreadsPerWarp; i <<= 1) { + T partial_sum = __shfl_up_sync(0xffffffff, val, i); + if (lane >= i) { val += partial_sum; } + } + return val; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_DEVICE T warp_uniform(T a) { + return __shfl_sync(0xffffffff, a, 0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTLASS_DEVICE +int canonical_warp_group_idx_nosync() { + return threadIdx.x / cutlass::NumThreadsPerWarpGroup; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..3f0f55756be2aa902f13fa92102260d08e73455a --- /dev/null +++ b/tests/test_flash_attn.py @@ -0,0 +1,1173 @@ +import os +import math +import itertools + +import pytest +import torch +import torch.nn.functional as F +from torch._C import parse_schema + +from einops import rearrange, repeat +try: + from flash_attn.layers.rotary import apply_rotary_emb +except ImportError: + apply_rotary_emb = None + +from padding import 
pad_input, unpad_input +from test_util import ( + attention_ref, + generate_qkv, + generate_random_padding_mask, +) + +from flash_attn3 import flash_attn_func, flash_attn_varlen_func, flash_attn_combine +from flash_attn3 import flash_attn_with_kvcache, get_scheduler_metadata + +from flash_attn3._ops import ops + + +DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" +DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" +DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE" +DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE" +DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "FALSE") == "TRUE" +DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE" +DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE" +DISABLE_FP16 = os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE" +DISABLE_FP8 = os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE" or torch.cuda.get_device_capability("cuda")[0] < 9 +DISABLE_HDIM64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM64", "FALSE") == "TRUE" +DISABLE_HDIM96 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM96", "FALSE") == "TRUE" +DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE" +DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" +DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" + +COMPILED_HDIMS = ( + [] + + ([64] if not DISABLE_HDIM64 else []) + + ([96] if not DISABLE_HDIM96 else []) + + ([128] if not DISABLE_HDIM128 else []) + + ([192] if not DISABLE_HDIM192 else []) + + ([256] if not DISABLE_HDIM256 else []) +) + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +# @pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("V_colmajor", [False, True]) +@pytest.mark.parametrize("V_colmajor", [False]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64, 128, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128, 192]) +@pytest.mark.parametrize("d", COMPILED_HDIMS) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (64, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), 
+ (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (4096, 4096), + (4224, 4224), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) +def test_flash_attn_output( + seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype +): + if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn): + pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn") + device = "cuda" + # set seed + torch.random.manual_seed(0) + # batch_size = 40 + # nheads = 16 + batch_size = 9 if seqlen_k <= 2048 else 2 + # batch_size = 1 + nheads = 6 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4) + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + # window_size = (-1, -1) if not local else (16, 0) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None + if V_colmajor: + v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_() + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() + # if qv is not None: + # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # exp_sum = s_tmp.sum(-1) + # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() 
/ math.sqrt(d), k_ref.float()) + # lse_ref = torch.logsumexp(qk, dim=-1) + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out, lse = flash_attn_func( + q, + k, + v, + causal=causal, + qv=qv, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + pack_gqa=pack_gqa, + num_splits=num_splits + ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + + if ( + not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not V_colmajor + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + ): + g = torch.randn_like(out) + do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) + # import flash_attn_3_cuda + # dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum = flash_attn_3_cuda.bwd( + # g, + # q, + # k, + # v, + # out, + # lse, + # None, + # None, + # None, + # d ** (-0.5), + # causal, + # window_size[0], window_size[1], + # softcap, + # deterministic, + # 0, # sm_margin + # ) + dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq 
- dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +# @pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("add_unused_qkv", [False, True]) +# @pytest.mark.parametrize("add_unused_qkv", [True]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128]) +@pytest.mark.parametrize("d", COMPILED_HDIMS) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (1, 3), + (2, 1), + (511, 1), + (3, 513), + (64, 128), + (128, 128), + (256, 256), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (307, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +def test_flash_attn_varlen_output( + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype +): + device = "cuda" + # set seed + torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) + # batch_size = 40 + # nheads = 16 + batch_size = 9 if seqlen_q <= 2048 else 2 + nheads = 6 + # batch_size = 2 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. 
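The output and gradient comparisons later in this varlen test reuse the tolerance pattern of test_flash_attn_output above. Written out as a standalone sketch for clarity (the helper name assert_close_to_ref is hypothetical; the tests inline this logic):

    def assert_close_to_ref(out, out_pt, out_ref, rtol):
        # The kernel's worst-case error vs. the fp32 reference must be at most
        # `rtol` times the error of the low-precision PyTorch reference, plus an
        # atol that estimates pure rounding noise on out_ref itself.
        fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
        assert (out - out_ref).abs().max().item() \
            <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol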
+ q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random", zero_lengths=False + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, batch_size, device, mode="random", zero_lengths=True + ) + + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask + + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device + ) + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) + + ( + q_unpad, + k_unpad, + v_unpad, + qv_unpad, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + qv, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False, + query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) + q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + pack_gqa_vals = [False, True] if not 
DISABLE_PACKGQA else [False] + num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out_unpad, lse = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + causal=causal, + qv=qv_unpad, + q_descale=q_descale, + k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + ) + out = output_pad_fn(out_unpad) + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + + + if ( + not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + ): + g_unpad = torch.randn_like(out_unpad) + do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) + # import flash_attn_3_cuda + # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen( + # g_unpad, + # q_unpad, + # k_unpad, + # v_unpad, + # out_unpad, + # lse, + # None, + # None, + # None, + # cu_seqlens_q, + # cu_seqlens_k, + # None, None, + # max_seqlen_q, + # max_seqlen_k, + # d ** (-0.5), + # causal, + # window_size[0], window_size[1], + # softcap, + # deterministic, + # 0, # sm_margin + # ) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad) + dq = dq_pad_fn(dq_unpad) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + dv.masked_fill_(k_zero_masking, 0.0) + if query_unused_mask is not None: + dq.masked_fill_(q_zero_masking, 0.0) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + g = output_pad_fn(g_unpad) + + # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float() + # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max 
diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) +# @pytest.mark.parametrize("new_kv", [True]) +@pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) +# @pytest.mark.parametrize("causal,local", [(False, False)]) +@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +# @pytest.mark.parametrize("has_rotary_seqlens", [False]) +@pytest.mark.parametrize("rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False]) +# @pytest.mark.parametrize("rotary_interleaved", [True]) +@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) else [0.0]) +# @pytest.mark.parametrize("rotary_fraction", [0.0]) +@pytest.mark.parametrize("page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else [])) +# @pytest.mark.parametrize("page_size", [None]) +@pytest.mark.parametrize("has_leftpad", [False, True]) +# @pytest.mark.parametrize("has_leftpad", [False]) +@pytest.mark.parametrize("has_batch_idx", [False, True]) +# @pytest.mark.parametrize("has_batch_idx", [False]) +@pytest.mark.parametrize("varlen_q", [False, True]) +# @pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +@pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [192]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 128), + (1, 339), + (3, 1024), + (64, 800), + (64, 256), + (3, 799), + (64, 2048), + (16, 20000), + # (1, 128 * 1024), + # (16, 128 * 1024), + (128, 128), + (256, 512), # To test appending KV with more than 1 block + (2048, 3577), # Enough 
tile to test persistent scheduler + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +def test_flash_attn_kvcache( + seqlen_q, + seqlen_k, + d, + varlen_q, + has_batch_idx, + has_leftpad, + page_size, + rotary_fraction, + rotary_interleaved, + has_rotary_seqlens, + seqlen_new_eq_seqlen_q, + causal, + local, + new_kv, + mha_type, + dtype, +): + if page_size is not None and seqlen_k % page_size != 0: + pytest.skip() + if seqlen_q > seqlen_k and new_kv: + pytest.skip() + if not new_kv and rotary_fraction > 0.0: + pytest.skip() + if rotary_fraction == 0.0 and has_rotary_seqlens: + pytest.skip() + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 5 + # batch_size = 1 + batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 + nheads = 6 + # nheads = 1 + # rotary_dim must be a multiple of 16, and must be <= d + rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + assert nheads % nheads_k == 0 + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) and not DISABLE_LOCAL else [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + has_qv = d == 64 and dv >= 256 + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if has_qv: + qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv = None + if varlen_q: + query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + qv_unpad = rearrange(qv, "b s ... 
-> (b s) ...")[indices_q] if has_qv else None + else: + query_padding_mask = None + q_unpad = q + qv_unpad = qv + cu_seqlens_q, max_seqlen_q = None, None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + + seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() + cu_seqlens_k_new = None + key_new_padding_mask = None + if new_kv: + k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v = torch.randn(batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if varlen_q: # k & v are also varlen + key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random") + k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask) + v_unpad, *rest = unpad_input(v, key_new_padding_mask) + else: + k_unpad, v_unpad = k, v + else: + k, v, k_unpad, v_unpad = None, None, None, None + if page_size is None: + k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + page_table = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype, dtype_ref + ) + cache_seqlens = torch.randint( + 0 if new_kv else 1, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough + ( + (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) + if new_kv + else (seqlen_k + 1) + ), + (batch_size,), + dtype=torch.int32, + device=device, + ) + if has_leftpad: + cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) + if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) + for i in range(batch_size)]) + else: + cache_leftpad = None + if has_batch_idx: + cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ + :batch_size + ] + else: + cache_batch_idx = None + arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + if not new_kv: + key_padding_mask = arange < cache_seqlens_expanded + else: + k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens + if has_leftpad: + key_padding_mask = torch.logical_and( + key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) + ) + # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2 + if rotary_dim > 0: + angle = ( + torch.rand( + seqlen_k if page_size is None else num_blocks * page_size, + rotary_dim // 2, + device=device, + ) + * 2 + * math.pi + ) + cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + if causal or local: + q_ro = apply_rotary_emb( + q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + 
cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=seqlen_q, + ) + # q_ro = q + k_ro = apply_rotary_emb( + k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved + ) + else: + cos, sin = None, None + q_ro, k_ro = q, k + # k_cache[:, 64:] = -1 + k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone() + v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone() + if new_kv: + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens + ) + k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") + v_to_update = rearrange(v, "b s ... -> (b s) ...") + if varlen_q: + k_to_update = k_to_update[indices_k] + v_to_update = v_to_update[indices_k] + k_cache_ref[update_mask] = k_to_update + v_cache_ref[update_mask] = v_to_update + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + out_ref, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + attention_chunk=attention_chunk, + key_leftpad=cache_leftpad, + ) + out_pt, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + attention_chunk=attention_chunk, + upcast=False, + reorder_ops=True, + key_leftpad=cache_leftpad, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None + ) + q = q.to(dtype) + q_unpad = q_unpad.to(dtype) if varlen_q else None + k_cache = k_cache.to(dtype) + v_cache = v_cache.to(dtype) + k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None + v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None + k = k.to(dtype) if k is not None else None + v = v.to(dtype) if v is not None else None + k_unpad = k_unpad.to(dtype) if k_unpad is not None else None + v_unpad = v_unpad.to(dtype) if v_unpad is not None else None + qv = qv.to(dtype) if qv is not None else None + qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None + cos = cos.to(dtype) if cos is not None else None + sin = sin.to(dtype) if sin is not None else None + k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() + v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() + num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1] + precompute_metadata_vals = [False, True] + for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals): + if precompute_metadata: + scheduler_metadata = get_scheduler_metadata( + batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d, + cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, + max_seqlen_k_new=seqlen_new, page_size=page_size, + causal=causal, window_size=window_size, attention_chunk=attention_chunk, + num_splits=num_splits + ) + else: + scheduler_metadata = None + # Repeat to test metadata reuse + for _ in range(1 if not precompute_metadata else 2): + if page_size is None: + k_cache.copy_(k_cache_saved) + v_cache.copy_(v_cache_saved) + else: + k_cache_paged.copy_(k_cache_saved) + v_cache_paged.copy_(v_cache_saved) + out, lse, *rest = flash_attn_with_kvcache( + q if not varlen_q else 
q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + k if not new_kv or not varlen_q else k_unpad, + v if not new_kv or not varlen_q else v_unpad, + qv=qv if not varlen_q else qv_unpad, + rotary_cos=cos, + rotary_sin=sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + rotary_seqlens=rotary_seqlens, + causal=causal, + window_size=window_size, + attention_chunk=attention_chunk, + rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + return_softmax_lse=True + ) + if varlen_q: + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) + if dtype is not torch.float8_e4m3fn: + assert torch.equal(v_cache_select, v_cache_ref) + else: + assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: + assert torch.equal(k_cache_select, k_cache_ref) + else: + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + else: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) + mult = 4 if dtype == torch.float8_e4m3fn else 2 + assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 + mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() + + +def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref): + num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 + k_cache_paged = torch.randn( + num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref + ).to(dtype).to(dtype_ref) + v_cache_paged = torch.randn( + num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref + ).to(dtype).to(dtype_ref) + page_table = rearrange( + torch.randperm(num_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=batch_size, + ) + k_cache = rearrange( + k_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + v_cache = rearrange( + v_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize('causal', [False]) +@pytest.mark.parametrize('d', [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (64, 8192), + ], +) +def test_flash_attn_cluster(seqlen_q, seqlen_k, d, causal, dtype): + device = "cuda" + torch.random.manual_seed(0) + batch_size = 2 + nheads = 16 + nheads_kv = 4 + # There was a bug where this would cause "unspecified launch failure" due to Cluster + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype) + v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype) + for _ in range(100): + flash_attn_func(q, k, v, causal=causal) + + +# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize('causal', [False]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128]) +# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [80]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 239), + (239, 1), + (3, 799), + (799, 3), + (1024, 128), + (97, 97), + (128, 128), + (200, 200), + (256, 256), + (257, 257), + (384, 384), + (512, 512), + (768, 768), + (1024, 1024), + (2048, 2048), + ], +) +def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype): + device = "cuda" + # set seed + torch.random.manual_seed(0) + # Simulate under memory load + dummy = torch.empty(70 * 1024 ** 3, dtype=torch.uint8, device=device) + batch_size = 60 # Sometimes we need large batch size for the race conditions to trigger + nheads = 4 + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + torch.random.manual_seed(42) + out0, lse0 = flash_attn_func(q, k, v, causal=causal) + g = torch.randn_like(out0) + dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g) + # Numerical error if we just do any arithmetic on dq + dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item() + + for i in range(1000): + torch.random.manual_seed(42) + out, lse = flash_attn_func(q, k, v, causal=causal) + assert torch.equal(out, out0) + assert torch.equal(lse, lse0) + + dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_equal = torch.allclose(dq, dq0, atol=dq_atol) + if not dq_equal: + print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}") + # breakpoint() + assert torch.equal(dv, dv0) + assert torch.equal(dk, dk0) + assert dq_equal + + +def attention_combine_ref(out_partial, lse_partial): + """ + out_partial: (num_splits, batch_size, seqlen, nheads, d) + lse_partial: (num_splits, batch_size, nheads, seqlen) + """ + lse = torch.logsumexp(lse_partial, dim=0) + scale = torch.exp(lse_partial - lse) + scale = torch.where(torch.isinf(scale) | torch.isnan(scale), 
torch.zeros_like(scale), scale) + out = (scale.unsqueeze(-1) * out_partial).sum(0) + return out, lse + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float32]) +# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +@pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024]) +# @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192]) +# @pytest.mark.parametrize("seqlen", [15]) +@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133]) +# @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11]) +# @pytest.mark.parametrize("num_splits", [128]) +def test_flash_attn_combine(num_splits, seqlen, d, dtype): + if DISABLE_SPLIT: + pytest.skip() + device = "cuda" + # set seed + torch.random.manual_seed(1) + batch_size = 5 + nheads = 16 + # batch_size = 1 + # nheads = 1 + out_partial = torch.randn(num_splits * 2, batch_size, nheads, seqlen, d, device=device, dtype=torch.float32).transpose(2, 3)[:num_splits] # To test non-contiguous tensor + lse_partial = torch.randn(num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32).transpose(-1, -2)[:, :, :, :nheads] # To test non-contiguous tensor + # To test short-circuiting based on num_splits + lse_partial[num_splits // 2:, :batch_size // 3] = -float("inf") + out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype) + out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) + out_pt = out_ref.to(dtype) + + print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + print(f"LSE mean diff: {(lse - lse_ref).abs().mean().item()}") + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5) + multiple = 2 + assert ((out - out_ref).abs().max().item() <= multiple * (out_pt - out_ref).abs().max().item()) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) + + # from flash_attn.utils.benchmark import pytorch_profiler + # # pytorch_profiler(torch.sum, lse_partial) + # pytorch_profiler(flash_attn_combine, out_partial, lse_partial) + # pytorch_profiler(torch.sum, out_partial) + +def test_flash3_bw_compatibility() -> None: + # Let's try to always stay backward compatible! This will make life easier + # for downstream libaries, users, and exported models. + # 1/ Instead of removing arguments, error out if their value is no longer supported + # 2/ When adding arguments, add them at the end with a default value + assert ops.fwd.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::fwd(Tensor q, Tensor k, Tensor v, Tensor(k_new!)? k_new=None, " + "Tensor(v_new!)? v_new=None, Tensor? q_v=None, Tensor(out!)? out=None, " + "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, " + "Tensor? cu_seqlens_k_new=None, Tensor? seqused_q=None, Tensor? seqused_k=None, " + "int? max_seqlen_q=None, int? max_seqlen_k=None, Tensor? page_table=None, " + "Tensor? kv_batch_idx=None, Tensor? leftpad_k=None, Tensor? rotary_cos=None, Tensor? rotary_sin=None, " + "Tensor? seqlens_rotary=None, Tensor? q_descale=None, Tensor? k_descale=None, Tensor? 
v_descale=None, " + "float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, int window_size_right=-1, " + "int attention_chunk=0, float softcap=0., bool is_rotary_interleaved=False, " + "Tensor? scheduler_metadata=None, int num_splits=0, bool? pack_gqa=None, int sm_margin=0) " + "-> (Tensor(out!), Tensor, Tensor, Tensor)" + )) + assert ops.bwd.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, " + "Tensor(dq!)? dq=None, Tensor(dk!)? dk=None, Tensor(dv!)? dv=None, Tensor? cu_seqlens_q=None, " + "Tensor? cu_seqlens_k=None, Tensor? seqused_q=None, Tensor? seqused_k=None, int? max_seqlen_q=None, " + "int? max_seqlen_k=None, float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, " + "int window_size_right=-1, float softcap=0., bool deterministic=False, int sm_margin=0) " + "-> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)" + )) + assert ops.fwd_combine.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::fwd_combine(Tensor out_partial, Tensor lse_partial, Tensor(out!)? out=None, " + "ScalarType? out_dtype=None) -> (Tensor(out!), Tensor)" + )) + assert ops.get_scheduler_metadata.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::get_scheduler_metadata(int batch_size, int max_seqlen_q, int max_seqlen_k, " + "int num_heads, int num_heads_k, int headdim, int headdim_v, ScalarType qkv_dtype, Tensor seqused_k, " + "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, Tensor? cu_seqlens_k_new=None, " + "Tensor? seqused_q=None, Tensor? leftpad_k=None, int? page_size=None, int max_seqlen_k_new=0, " + "bool is_causal=False, int window_size_left=-1, int window_size_right=-1, " + "int attention_chunk=0, bool has_softcap=False, int num_splits=0, bool? pack_gqa=None, " + "int sm_margin=0) -> Tensor" + )) diff --git a/torch-ext/flash_attn/__init__.py b/torch-ext/flash_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d12edd447e84d6e65c4fa9ab91ae3c3192884c1 --- /dev/null +++ b/torch-ext/flash_attn/__init__.py @@ -0,0 +1,17 @@ +from .flash_attn_interface import ( + flash_attn_combine, + flash_attn_func, + flash_attn_qkvpacked_func, + flash_attn_varlen_func, + flash_attn_with_kvcache, + get_scheduler_metadata, +) + +__all__ = [ + "flash_attn_combine", + "flash_attn_func", + "flash_attn_qkvpacked_func", + "flash_attn_varlen_func", + "flash_attn_with_kvcache", + "get_scheduler_metadata", +] diff --git a/torch-ext/flash_attn/flash_attn_interface.py b/torch-ext/flash_attn/flash_attn_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..7572297784a73fdf891af6aafee88428120b72f2 --- /dev/null +++ b/torch-ext/flash_attn/flash_attn_interface.py @@ -0,0 +1,828 @@ +# Copyright (c) 2023, Tri Dao. 
+ +from typing import Optional, Union + +import torch +import torch.nn as nn + +from ._ops import ops as flash_attn_3_cuda + +def maybe_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +def _flash_attn_forward( + q, + k, + v, + k_new, + v_new, + qv, + out, + cu_seqlens_q, + cu_seqlens_k, + cu_seqlens_k_new, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + page_table, + kv_batch_idx, + leftpad_k, + rotary_cos, + rotary_sin, + seqlens_rotary, + q_descale, + k_descale, + v_descale, + softmax_scale, + causal, + window_size=(-1, -1), + attention_chunk=0, + softcap=0.0, + rotary_interleaved=True, + scheduler_metadata=None, + num_splits=1, + pack_gqa=None, + sm_margin=0): + q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)] + v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v + cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [ + maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new) + ] + seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)] + page_table, kv_batch_idx, leftpad_k = [ + maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k) + ] + rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] + seqlens_rotary = maybe_contiguous(seqlens_rotary) + out, softmax_lse, *rest = flash_attn_3_cuda.fwd( + q, + k, + v, + k_new, + v_new, + qv, + out, + cu_seqlens_q, + cu_seqlens_k, + cu_seqlens_k_new, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + page_table, + kv_batch_idx, + leftpad_k, + rotary_cos, + rotary_sin, + seqlens_rotary, + q_descale, + k_descale, + v_descale, + softmax_scale, + causal, + window_size[0], + window_size[1], + attention_chunk, + softcap, + rotary_interleaved, + scheduler_metadata, + num_splits, + pack_gqa, + sm_margin, + ) + return out, softmax_lse, *rest + + +def _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + cu_seqlens_q, + cu_seqlens_k, + sequed_q, + sequed_k, + max_seqlen_q, + max_seqlen_k, + dq, + dk, + dv, + softmax_scale, + causal, + window_size=(-1, -1), + softcap=0.0, + deterministic=False, + sm_margin=0, +): + # dq, dk, dv are allocated by us so they should already be contiguous + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + sequed_q, + sequed_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + causal, + window_size[0], + window_size[1], + softcap, + deterministic, + sm_margin, + ) + return dq, dk, dv, softmax_d + + +class FlashAttnQKVPackedFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + softmax_scale, + causal, + q_descale=None, k_descale=None, v_descale=None, + window_size=(-1, -1), + attention_chunk=0, + softcap=0.0, + deterministic=False, + num_heads_q=None, + sm_margin=0, + ): + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + if qkv.dim() == 5: + assert qkv.shape[-3] == 3 + q, k, v = qkv.unbind(dim=-3) + else: + assert qkv.dim() == 4 + assert num_heads_q is not None + num_heads_k = (qkv.shape[2] - num_heads_q) // 2 + assert num_heads_k * 2 + num_heads_q == qkv.shape[2] + q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2) + out, softmax_lse, *rest = _flash_attn_forward( + q, + k, + v, + None, None, # k_new, v_new + None, # qv + None, # out + None, None, None, # cu_seqlens_q/k/k_new + None, None, # seqused_q/k + 
None, None, # max_seqlen_q/k + None, None, None, # page_table, kv_batch_idx, leftpad_k, + None, None, None, # rotary_cos/sin, seqlens_rotary + q_descale, k_descale, v_descale, + softmax_scale, + causal=causal, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + sm_margin=sm_margin, + ) + # ctx.save_for_backward(q, k, v, out_padded, softmax_lse) + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.attention_chunk = attention_chunk + ctx.softcap = softcap + ctx.deterministic = deterministic + ctx.ndim = qkv.dim() + ctx.sm_margin = sm_margin + # return out, softmax_lse + return out + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse = ctx.saved_tensors + assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk" + if ctx.ndim == 5: + qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) + dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) + dq, dk, dv = dqkv.unbind(dim=-3) + else: + num_heads_q = q.shape[2] + num_heads_k = k.shape[2] + qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:]) + dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) + dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2) + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + None, None, # cu_seqlens_q, cu_seqlens_k, + None, None, # sequed_q, sequed_k, + None, None, # max_seqlen_q, max_seqlen_k, + dq, + dk, + dv, + ctx.softmax_scale, + ctx.causal, + ctx.window_size, + ctx.softcap, + ctx.deterministic, + ctx.sm_margin, + ) + dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension + return dqkv, None, None, None, None, None, None, None, None, None, None, None + + +class FlashAttnFunc(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + softmax_scale, + causal, + qv=None, + q_descale=None, k_descale=None, v_descale=None, + window_size=(-1, -1), + attention_chunk=0, + softcap=0.0, + num_splits=1, + pack_gqa=None, + deterministic=False, + sm_margin=0, + ): + if softmax_scale is None: + softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) + # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward( + out, softmax_lse, *rest = _flash_attn_forward( + q, + k, + v, + None, None, # k_new, v_new + qv, # qv + None, # out + None, None, None, # cu_seqlens_q/k/k_new + None, None, # seqused_q/k + None, None, # max_seqlen_q/k + None, None, None, # page_table, kv_batch_idx, leftpad_k, + None, None, None, # rotary_cos/sin, seqlens_rotary + q_descale, k_descale, v_descale, + softmax_scale, + causal=causal, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + sm_margin=sm_margin, + ) + # ctx.save_for_backward(q, k, v, out_padded, softmax_lse) + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.attention_chunk = attention_chunk + ctx.softcap = softcap + ctx.deterministic = deterministic + ctx.sm_margin = sm_margin + return out, softmax_lse + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse = ctx.saved_tensors + assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk" + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + _flash_attn_backward( + dout, + q, + k, + v, + out, + 
softmax_lse, + None, None, # cu_seqlens_q, cu_seqlens_k, + None, None, # sequed_q, sequed_k, + None, None, # max_seqlen_q, max_seqlen_k, + dq, + dk, + dv, + ctx.softmax_scale, + ctx.causal, + ctx.window_size, + ctx.softcap, + ctx.deterministic, + ctx.sm_margin, + ) + dq = dq[..., : q.shape[-1]] # We could have padded the head dimension + dk = dk[..., : k.shape[-1]] + dv = dv[..., : v.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None + + +class FlashAttnVarlenFunc(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + causal, + qv=None, + q_descale=None, k_descale=None, v_descale=None, + window_size=(-1, -1), + attention_chunk=0, + softcap=0.0, + num_splits=1, + pack_gqa=None, + deterministic=False, + sm_margin=0, + ): + if softmax_scale is None: + softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) + # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward( + out, softmax_lse, *rest = _flash_attn_forward( + q, + k, + v, + None, None, # k_new, v_new + qv, # qv + None, # out + cu_seqlens_q, + cu_seqlens_k, + None, # cu_seqlens_k_new + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + None, None, None, # page_table, kv_batch_idx, leftpad_k, + None, None, None, # rotary_cos/sin, seqlens_rotary + q_descale, k_descale, v_descale, + softmax_scale, + causal=causal, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + sm_margin=sm_margin, + ) + # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) + ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.attention_chunk = attention_chunk + ctx.softcap = softcap + ctx.deterministic = deterministic + ctx.sm_margin = sm_margin + return out, softmax_lse + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors + assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk" + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + dq, + dk, + dv, + ctx.softmax_scale, + ctx.causal, + ctx.window_size, + ctx.softcap, + ctx.deterministic, + ctx.sm_margin, + ) + dq = dq[..., : q.shape[-1]] # We could have padded the head dimension + dk = dk[..., : k.shape[-1]] + dv = dv[..., : v.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + + +def flash_attn_qkvpacked_func( + qkv, + softmax_scale=None, + causal=False, + q_descale=None, k_descale=None, v_descale=None, + window_size=(-1, -1), + attention_chunk=0, + softcap=0.0, + deterministic=False, + num_heads_q=None, + sm_margin=0, +): + """dropout_p should be set to 0.0 during evaluation + If Q, K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_func on Q, K, V since the backward pass avoids explicit 
concatenation + of the gradients of Q, K, V. + For multi-query and grouped-query attention (MQA/GQA), please see + flash_attn_kvpacked_func and flash_attn_func. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. + + Arguments: + qkv: (batch_size, seqlen, 3, nheads, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to + the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnQKVPackedFunc.apply( + qkv, + softmax_scale, + causal, + q_descale, k_descale, v_descale, + window_size, + attention_chunk, + softcap, + deterministic, + num_heads_q, + sm_margin, + ) + + +def flash_attn_func( + q, + k, + v, + softmax_scale=None, + causal=False, + qv=None, + q_descale=None, k_descale=None, v_descale=None, + window_size=(-1, -1), + attention_chunk=0, + softcap=0.0, + num_splits=1, + pack_gqa=None, + deterministic=False, + sm_margin=0, +): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k: (batch_size, seqlen, nheads_k, headdim) + v: (batch_size, seqlen, nheads_k, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. 
The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + """ + return FlashAttnFunc.apply( + q, + k, + v, + softmax_scale, + causal, + qv, + q_descale, k_descale, v_descale, + window_size, + attention_chunk, + softcap, + num_splits, + pack_gqa, + deterministic, + sm_margin, + ) + + +def flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + seqused_q=None, + seqused_k=None, + softmax_scale=None, + causal=False, + qv=None, + q_descale=None, k_descale=None, v_descale=None, + window_size=(-1, -1), + attention_chunk=0, + softcap=0.0, + num_splits=1, + pack_gqa=None, + deterministic=False, + sm_margin=0, +): + return FlashAttnVarlenFunc.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + causal, + qv, + q_descale, k_descale, v_descale, + window_size, + attention_chunk, + softcap, + num_splits, + pack_gqa, + deterministic, + sm_margin, + ) + + +def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None): + return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype) + + +def flash_attn_with_kvcache( + q, + k_cache, + v_cache, + k=None, + v=None, + qv=None, + rotary_cos=None, + rotary_sin=None, + cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, + cache_batch_idx: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + rotary_seqlens: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + attention_chunk=0, + softcap=0.0, # 0.0 means deactivated + rotary_interleaved=True, + scheduler_metadata=None, + num_splits=0, # Can be tuned for speed + pack_gqa=None, # Can be tuned for speed + sm_margin=0, # Can be tuned if some SMs are used for communication + return_softmax_lse=False, +): + """ + If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from + k and v. This is useful for incremental decoding: you can pass in the cached keys/values from + the previous step, and update them with the new keys/values from the current step, and do + attention with the updated cache, all in 1 kernel. 
+ + If you pass in k / v, you must make sure that the cache is large enough to hold the new values. + For example, the KV cache could be pre-allocated with the max sequence length, and you can use + cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. + + Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be + rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos + and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at + indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). + + See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. + + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Note: Does not support backward pass. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table, + or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache) + page_block_size must be a multiple of 256. + v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table, + or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache) + k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate + k with k_cache, starting at the indices specified by cache_seqlens. + v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k. + qv [optional]: (batch_size, seqlen, nheads, headdim_v) + rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding + to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. + rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. + cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the + KV cache. + cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. + If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. + If the indices are not distinct, and k and v are provided, the values updated in the cache + might come from any of the duplicate indices. + cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0. 
+ page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. + If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, + rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 + (i.e. GPT-NeoX style). + num_splits: int. If > 1, split the key/value into this many chunks along the sequence. + If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic + to automatically determine the number of splits. + Don't change this unless you know what you are doing. + return_softmax_lse: bool. Whether to return the logsumexp of the attention scores. + + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + """ + assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" + assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" + if softmax_scale is None: + softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) + if cache_seqlens is not None and isinstance(cache_seqlens, int): + cache_seqlens = torch.full( + (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device + ) + cache_seqlens = maybe_contiguous(cache_seqlens) + out, softmax_lse, *rest = _flash_attn_forward( + q, + k_cache, + v_cache, + k, + v, + qv, + None, # out + cu_seqlens_q, + None, # cu_seqlens_k + cu_seqlens_k_new, + None, # seqused_q + cache_seqlens, + max_seqlen_q, + None, # max_seqlen_k + page_table, + cache_batch_idx, + cache_leftpad, + rotary_cos, + rotary_sin, + rotary_seqlens, + q_descale, k_descale, v_descale, + softmax_scale, + causal=causal, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + pack_gqa=pack_gqa, + sm_margin=sm_margin, + ) + # return (out, softmax_lse) if return_softmax_lse else out + return (out, softmax_lse, *rest) if return_softmax_lse else out + + +def get_scheduler_metadata( + batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, + cache_seqlens: torch.Tensor, + qkv_dtype=torch.bfloat16, + headdim_v=None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_size: Optional[int] = None, + max_seqlen_k_new=0, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + attention_chunk=0, + has_softcap=False, + num_splits=0, # Can be tuned for speed + pack_gqa=None, # Can be tuned for speed + sm_margin=0, # Can be tuned if some SMs are used for communication +): + cache_seqlens = maybe_contiguous(cache_seqlens) + if headdim_v is None: + headdim_v = headdim + scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata( + batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v, + qkv_dtype, + cache_seqlens, + 
cu_seqlens_q, + None, # cu_seqlens_k + cu_seqlens_k_new, + None, # seqused_q + cache_leftpad, + page_size, + max_seqlen_k_new, + causal, + window_size[0], window_size[1], + attention_chunk, + has_softcap, + num_splits, + pack_gqa, + sm_margin, + ) + return scheduler_metadata diff --git a/torch-ext/pytorch_shim.h b/torch-ext/pytorch_shim.h new file mode 100644 index 0000000000000000000000000000000000000000..f27e150c827e457a93a45a4d2784a23871a73b6b --- /dev/null +++ b/torch-ext/pytorch_shim.h @@ -0,0 +1,105 @@ +#pragma once + +#include + +/** + * Unforunately, the type signatures of the flash_attn ops are not compatible + * with the PyTorch library bindings. To get around that we use + * `make_pytorch_shim` which creates a lambda that exponses the API using + * PyTorch compatible types to the types, then converts them to the types + * expected by the flash_attn ops. This shims allows us to make minimal changes + * to `flash_api.cpp` making it easier to synchronize with upstream changes. + * + * The `pytorch_library_compatible_type` struct is used to map from the + * flash_attn ops types to a PyTorch library compatible one. The main issues is + * that the following types are not support by PyTorch libary bindings: + * - `int` + * - `float` + * - `std::optional &` + * - `std::optional &` + * So we convert them to (respectively): + * - `int64_t` + * - `double` + * - `const std::optional&` + * - `const std::optional&` + */ + +template +struct pytorch_library_compatible_type { + using type = T; + static T convert_from_type(T arg) { return arg; } +}; + +template +using pytorch_library_compatible_type_t = \ + typename pytorch_library_compatible_type::type; + +template +T convert_from_pytorch_compatible_type(pytorch_library_compatible_type_t arg) + { return pytorch_library_compatible_type::convert_from_type(arg); } + +// Map `std::optional &` -> `const std::optional&` +// (NOTE: this is bit unsafe but non of the ops in flash_attn mutate +// the optional container) +template +struct pytorch_library_compatible_type &> { + using type = const std::optional&; + static std::optional& convert_from_type(const std::optional &arg) { + return const_cast&>(arg); + } +}; + +// Map `std::optional` -> +// `std::optional>` +// (NOTE: tested for `std::optional` -> `std::optional`) +template +struct pytorch_library_compatible_type> { + using type = std::optional>; + static std::optional> convert_from_type(std::optional arg) { + return arg; + } +}; + +// Map `std::optional&` -> `const std::optional&` +template<> +struct pytorch_library_compatible_type &> { + using type = const std::optional&; + static std::optional& convert_from_type( + const std::optional &arg) { + return const_cast&>( + reinterpret_cast&>(arg)); + } +}; + +// Map `int` -> `int64_t` +template<> struct pytorch_library_compatible_type { + using type = int64_t; + static int convert_from_type(int64_t arg) { + TORCH_CHECK(arg <= std::numeric_limits::max(), + "int64_t value is too large to be converted to int"); + TORCH_CHECK(arg >= std::numeric_limits::min(), + "int64_t value is too small to be converted to int"); + return arg; + } +}; + +// Map `float` -> `double` +template<> struct pytorch_library_compatible_type { + using type = double; + static float convert_from_type(double arg) { + TORCH_CHECK(std::abs(arg) <= std::numeric_limits::max(), + "double value is too large to be converted to float"); + return arg; + } +}; + +// +// Shim Utils +// + +template +auto make_pytorch_shim(Ret(*fun)(Args... 
args)){ + return [fun](pytorch_library_compatible_type_t... args) { + return fun(convert_from_pytorch_compatible_type(args)...); + }; +} diff --git a/torch-ext/torch_binding.cpp b/torch-ext/torch_binding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2daccdd2a4788f0dc675c26ef1df3ce35be8ca5d --- /dev/null +++ b/torch-ext/torch_binding.cpp @@ -0,0 +1,103 @@ +#include + +#include "pytorch_shim.h" +#include "registration.h" +#include "torch_binding.h" + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + m.def("fwd(" + "Tensor q," + "Tensor k," + "Tensor v," + "Tensor(k_new!)? k_new = None," + "Tensor(v_new!)? v_new = None," + "Tensor? q_v = None," + "Tensor(out!)? out = None," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? cu_seqlens_k_new = None," + "Tensor? seqused_q = None," + "Tensor? seqused_k = None," + "int? max_seqlen_q = None," + "int? max_seqlen_k = None," + "Tensor? page_table = None," + "Tensor? kv_batch_idx = None," + "Tensor? leftpad_k = None," + "Tensor? rotary_cos = None," + "Tensor? rotary_sin = None," + "Tensor? seqlens_rotary = None," + "Tensor? q_descale = None," + "Tensor? k_descale = None," + "Tensor? v_descale = None," + "float? softmax_scale = None," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "int attention_chunk = 0," + "float softcap = 0.0," + "bool is_rotary_interleaved = False," + "Tensor? scheduler_metadata = None," + "int num_splits = 0," + "bool? pack_gqa = None," + "int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)"); + m.def("bwd(" + "Tensor dout," + "Tensor q," + "Tensor k," + "Tensor v," + "Tensor out," + "Tensor softmax_lse," + "Tensor(dq!)? dq = None," + "Tensor(dk!)? dk = None," + "Tensor(dv!)? dv = None," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? seqused_q = None," + "Tensor? seqused_k = None," + "int? max_seqlen_q = None," + "int? max_seqlen_k = None," + "float? softmax_scale = None," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "float softcap = 0.0," + "bool deterministic = False," + "int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)"); + m.def("fwd_combine(" + "Tensor out_partial," + "Tensor lse_partial," + "Tensor(out!)? out = None," + "ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)"); + m.def("get_scheduler_metadata(" + "int batch_size," + "int max_seqlen_q," + "int max_seqlen_k," + "int num_heads," + "int num_heads_k," + "int headdim," + "int headdim_v," + "ScalarType qkv_dtype," + "Tensor seqused_k," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? cu_seqlens_k_new = None," + "Tensor? seqused_q = None," + "Tensor? leftpad_k = None," + "int? page_size = None," + "int max_seqlen_k_new = 0," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "int attention_chunk = 0," + "bool has_softcap = False," + "int num_splits = 0," + "bool? 
pack_gqa = None," + "int sm_margin = 0) -> Tensor"); + + m.impl("fwd", &mha_fwd); + m.impl("bwd", &mha_bwd); + m.impl("fwd_combine", &mha_combine); + m.impl("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/torch-ext/torch_binding.h b/torch-ext/torch_binding.h new file mode 100644 index 0000000000000000000000000000000000000000..99d99bc21cb39e65f6ef98906296ace3a8cb97da --- /dev/null +++ b/torch-ext/torch_binding.h @@ -0,0 +1,106 @@ +#pragma once + +#include +#include + +#include + +std::tuple +mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + at::Tensor k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. + at::Tensor v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. + std::optional k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new + std::optional v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new + std::optional q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q + std::optional out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional cu_seqlens_k_new_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional seqused_k_, // b. If given, only this many elements of each batch element's keys are used. + std::optional max_seqlen_q_, + // TODO: check if we need max_seqlen_k + std::optional max_seqlen_k_, + std::optional page_table_, // (b_k, max_num_pages_per_seq) + std::optional kv_batch_idx_, // b. indices to index into the KV cache + std::optional leftpad_k_, // b + std::optional rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional seqlens_rotary_, // b + std::optional q_descale_, // (b, h_k), not (b, h) + std::optional k_descale_, // (b, h_k) + std::optional v_descale_, // (b, h_k) + std::optional softmax_scale_, + bool is_causal, + int64_t window_size_left, + int64_t window_size_right, + int64_t attention_chunk, + double softcap, + bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + std::optional scheduler_metadata_, // (b + 1) + int64_t num_splits, + std::optional pack_gqa_, + int64_t sm_margin + ); + +std::tuple +mha_bwd( + at::Tensor dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + at::Tensor k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k + at::Tensor v, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k + at::Tensor out, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + at::Tensor softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q + std::optional dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + std::optional dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k + std::optional dv_, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional seqused_q_, // b. 
If given, only this many elements of each batch element's queries and outputs are used. + std::optional seqused_k_, // b. If given, only this many elements of each batch element's keys are used. + std::optional max_seqlen_q_, + std::optional max_seqlen_k_, + std::optional softmax_scale_, + bool is_causal, + int64_t window_size_left, + int64_t window_size_right, + double softcap, + bool deterministic, + int64_t sm_margin +); + +std::tuple +mha_combine(at::Tensor out_partial, // num_splits x batch_size x seqlen x num_heads x head_size + at::Tensor lse_partial, // num_splits x batch_size x seqlen x num_heads + std::optional out_, // batch_size x seqlen x num_heads x head_size + std::optional out_dtype_ + ); + +at::Tensor +mha_fwd_get_scheduler_metadata( + int64_t batch_size, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + int64_t num_heads, + int64_t num_heads_k, + int64_t headdim, + int64_t headdim_v, + at::ScalarType qkv_dtype, + at::Tensor seqused_k, // b + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional cu_seqlens_k_new_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional leftpad_k_, // b + std::optional page_size, + int64_t max_seqlen_k_new, // 0 means we're not appending new KV + bool is_causal, + int64_t window_size_left, + int64_t window_size_right, + int64_t attention_chunk, + bool has_softcap, + int64_t num_splits, + std::optional pack_gqa_, + int64_t sm_margin + ); +
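A minimal sketch of incremental decoding with `flash_attn_with_kvcache`, following the docstring in flash_attn_interface.py above. It assumes the package is importable as `flash_attn` (as exported by torch-ext/flash_attn/__init__.py); the batch/head/dim sizes, tensor names, and the prefill-then-decode loop are illustrative only.

```python
# Sketch: pre-allocate a KV cache at max length, then let the kernel append
# new K/V in-place at offset cache_seqlens and attend to the updated cache.
import torch
from flash_attn import flash_attn_with_kvcache

device, dtype = "cuda", torch.bfloat16
batch, nheads, nheads_k, d = 2, 8, 2, 128   # illustrative GQA sizes (8 query heads, 2 KV heads)
max_seqlen = 4096

k_cache = torch.zeros(batch, max_seqlen, nheads_k, d, device=device, dtype=dtype)
v_cache = torch.zeros(batch, max_seqlen, nheads_k, d, device=device, dtype=dtype)
cache_seqlens = torch.zeros(batch, dtype=torch.int32, device=device)  # current lengths

# Prefill with a 16-token prompt, then decode one token at a time.
for step_len in (16, 1, 1):
    q = torch.randn(batch, step_len, nheads, d, device=device, dtype=dtype)
    k = torch.randn(batch, step_len, nheads_k, d, device=device, dtype=dtype)
    v = torch.randn(batch, step_len, nheads_k, d, device=device, dtype=dtype)
    # k/v are written into k_cache/v_cache in-place starting at cache_seqlens,
    # and attention runs against the updated cache, all in one kernel.
    out = flash_attn_with_kvcache(
        q, k_cache, v_cache, k=k, v=v,
        cache_seqlens=cache_seqlens, causal=True,
    )
    cache_seqlens += step_len  # the caller advances the lengths for the next step
```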
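A sketch of precomputing scheduler metadata once and reusing it across repeated calls, mirroring the `precompute_metadata` path in the kv-cache test above. The sizes are illustrative; reuse is only meaningful while the shapes, dtype, and options passed to `get_scheduler_metadata` continue to match the subsequent attention calls (the test reuses the same metadata across two identical calls).

```python
# Sketch: compute scheduler metadata for a fixed decode shape, then pass it to
# every flash_attn_with_kvcache call that uses the same sizes and options.
import torch
from flash_attn import flash_attn_with_kvcache, get_scheduler_metadata

device, dtype = "cuda", torch.bfloat16
batch, nheads, nheads_k, d = 4, 16, 4, 128
seqlen_q, seqlen_k = 1, 8192            # single-token decode against a long cache

q = torch.randn(batch, seqlen_q, nheads, d, device=device, dtype=dtype)
k_cache = torch.randn(batch, seqlen_k, nheads_k, d, device=device, dtype=dtype)
v_cache = torch.randn(batch, seqlen_k, nheads_k, d, device=device, dtype=dtype)
cache_seqlens = torch.full((batch,), seqlen_k, dtype=torch.int32, device=device)

scheduler_metadata = get_scheduler_metadata(
    batch, seqlen_q, seqlen_k, nheads, nheads_k, d,
    cache_seqlens, q.dtype,
    causal=True, num_splits=0,          # num_splits=0 lets the heuristic choose
)

for _ in range(10):                     # e.g. repeated decode steps, sizes unchanged
    out = flash_attn_with_kvcache(
        q, k_cache, v_cache,
        cache_seqlens=cache_seqlens,
        causal=True,
        scheduler_metadata=scheduler_metadata,
        num_splits=0,                   # must match what the metadata was built for
    )
```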
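A sketch of the varlen path with packed tensors and cumulative sequence offsets, using the shapes documented in torch_binding.h: `(total_q, h, d)` inputs plus `cu_seqlens` of length `b + 1`. The per-sequence lengths chosen here are illustrative.

```python
# Sketch: three sequences of lengths 5, 17, 3 packed into one (total, heads, d)
# tensor; cu_seqlens holds the cumulative offsets [0, 5, 22, 25].
import torch
from flash_attn import flash_attn_varlen_func

device, dtype = "cuda", torch.bfloat16
nheads, d = 8, 128
seqlens = [5, 17, 3]
total = sum(seqlens)
cu_seqlens = torch.tensor([0, 5, 22, 25], dtype=torch.int32, device=device)
max_seqlen = max(seqlens)

q = torch.randn(total, nheads, d, device=device, dtype=dtype)
k = torch.randn(total, nheads, d, device=device, dtype=dtype)
v = torch.randn(total, nheads, d, device=device, dtype=dtype)

out, lse = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
    causal=True,
)
# out: (total, nheads, d); lse: logsumexp of the attention scores per row.
```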