File size: 26,008 Bytes
eb8ddce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 |
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include "cute/tensor.hpp"
#include <cutlass/cutlass.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/kernel_hardware_info.h>
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/arch/grid_dependency_control.h"
#include "seqlen.h"
#include "utils.h"
#include "softmax.h"
namespace flash {
using namespace cute;
template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
class FlashAttnFwdSm90 {
public:
// Type Aliases
// The collectives and the tile scheduler are supplied as template parameters. The aliases and
// constants below re-export their compile-time configuration under short names for use in this class.
using CollectiveMainloop = CollectiveMainloop_;
using CollectiveEpilogue = CollectiveEpilogue_;
// Compile-time feature flags, each forwarded verbatim from the collective named on the right-hand side.
static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
static constexpr bool Is_local = CollectiveMainloop::Is_local;
// The mainloop and the epilogue must have been instantiated with the same Varlen setting.
static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
static constexpr bool Varlen = CollectiveMainloop::Varlen;
static constexpr bool Split = CollectiveMainloop::Split;
static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
static constexpr bool HasQv = CollectiveMainloop::HasQv;
static constexpr bool Use_TMA_Q = CollectiveMainloop::Use_TMA_Q;
static constexpr bool Use_TMA_KV = CollectiveMainloop::Use_TMA_KV;
// Unlike the other flags, Use_TMA_O is taken from the epilogue.
static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O;
static constexpr bool PackGQA = CollectiveMainloop::PackGQA;
static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim;
static constexpr bool LargeHeadDimV = CollectiveMainloop::LargeHeadDimV;
// The mainloop and the epilogue must also agree on LargeHeadDimV.
static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV);
using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;
// Mainloop derived types
using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV;
using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV;
using ArchTag = typename CollectiveMainloop::ArchTag;
using ClusterShape = typename CollectiveMainloop::ClusterShape;
using MainloopArguments = typename CollectiveMainloop::Arguments;
using MainloopParams = typename CollectiveMainloop::Params;
// Barrier type for Q (and Qv): a transaction barrier when Use_TMA_Q is set, a plain cluster barrier otherwise.
using BarrierQ = std::conditional_t<Use_TMA_Q, cutlass::arch::ClusterTransactionBarrier, cutlass::arch::ClusterBarrier>;
// Epilogue derived types
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
using EpilogueParams = typename CollectiveEpilogue::Params;
// This kernel is only instantiable for architecture tags with compute capability 90 or higher.
static_assert(ArchTag::kMinComputeCapability >= 90);
using TileScheduler = TileScheduler_;
// Scheduler arguments use the shared flash::TileSchedulerArguments type; only Params is scheduler-specific.
using TileSchedulerArguments = typename flash::TileSchedulerArguments;
using TileSchedulerParams = typename TileScheduler::Params;
// Thread-block composition: one warp group for loading plus however many warp groups TiledMmaPV spans.
static constexpr uint32_t NumLoadWarpGroups = 1;
static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaPV{})) / cutlass::NumThreadsPerWarpGroup;
// Total threads per block = MMA threads + load threads; this is the value returned by get_block_shape().
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaPV{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
// The register tables below only cover 1, 2 or 3 MMA warp groups.
static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
/// Register requirement for Load and Math WGs
// These values are the template arguments to warpgroup_reg_dealloc (load WG) and
// warpgroup_reg_alloc (MMA WGs) at the top of the producer / consumer branches in operator().
// If we use cp.async to load K and V, we need more registers for the producer WG.
static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 24 : 40) : 32);
static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 240 : 232) : 160);
// If you want to print from the producer warp, you'd need to increase the number of registers
// Otherwise you'll get CUDA error.
// static constexpr uint32_t LoadRegisterRequirement = 40;
// static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152;
// Kernel level shared memory storage
// We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v
// and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v).
// mainloop_smem_padding_ is sizeof(epilogue TensorStorage) - sizeof(mainloop smem_v) in bytes and may be
// negative; mainloop_smem_padding clamps it to zero.
static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)));
static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_;
// Layout of the buffer passed to operator() as smem_buf (it is reinterpret_cast to this type there).
struct SharedStorage {
// Tensor tiles. The union makes the epilogue storage alias the (padded) mainloop storage.
struct TensorStorage : cute::aligned_struct<128, _1> {
union {
struct {
// Padding is expressed in uint32_t elements, so the byte count is divided (rounding down) by 4.
// NOTE(review): exact alignment of smem_o with smem_v presumably relies on smem_v's position
// inside CollectiveMainloop::TensorStorage, which is not visible here - confirm.
cute::array<uint32_t, mainloop_smem_padding / sizeof(uint32_t)> padding_;
typename CollectiveMainloop::TensorStorage mainloop;
};
// We want smem_o to line up with the start of smem_v
typename CollectiveEpilogue::TensorStorage epilogue;
};
} tensors;
// Barriers and per-pipeline shared state, kept separate from the tensor tiles.
struct PipelineStorage : cute::aligned_struct<16, _1> {
alignas(16) BarrierQ barrier_Q;
// barrier_Qv is only initialized in operator() when HasQv is true.
alignas(16) BarrierQ barrier_Qv;
alignas(16) cutlass::arch::ClusterBarrier barrier_O;
alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k;
alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v;
alignas(16) typename CollectiveMainloop::MainloopPipelineVt::SharedStorage pipeline_vt;
// The *_new pipelines back the AppendKV path (they are only constructed when AppendKV is true).
alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_k_new;
alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_v_new;
// Handed to the TileScheduler constructor in operator().
alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
} pipelines;
};
// Size in bytes of SharedStorage.
// NOTE(review): presumably used by the launcher to size dynamic shared memory - confirm against caller.
static constexpr int SharedStorageSize = sizeof(SharedStorage);
// Device side arguments
// Bundle of per-component arguments; to_underlying_arguments() converts it into Params.
// Every member is value-initialized by default.
struct Arguments {
MainloopArguments mainloop{};
EpilogueArguments epilogue{};
// hw_info.sm_count <= 0 is treated as "not provided" by to_underlying_arguments(), which then queries the device.
cutlass::KernelHardwareInfo hw_info{};
TileSchedulerArguments scheduler{};
};
// Kernel entry point API
// Parameter bundle consumed by operator(); produced from Arguments by to_underlying_arguments().
struct Params {
MainloopParams mainloop{};
EpilogueParams epilogue{};
// hw_info.sm_count is forwarded to TileScheduler::get_grid_shape() by get_grid_shape().
cutlass::KernelHardwareInfo hw_info{};
TileSchedulerParams scheduler{};
};
//
// Methods
//
// Translate the user-facing Arguments into the Params consumed by the kernel.
// The mainloop, epilogue and tile scheduler each convert their own sub-arguments; the only work
// done here is resolving the SM count (queried from the device when the caller did not supply one).
static
Params
to_underlying_arguments(Arguments const& args) {
    CUTLASS_TRACE_HOST("to_underlying_arguments():");

    // A non-positive SM count means "not provided": warn, then fall back to a device query.
    bool const sm_count_missing = args.hw_info.sm_count <= 0;
    if (sm_count_missing) {
        CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
            " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
    }
    int const sm_count = sm_count_missing
        ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id)
        : args.hw_info.sm_count;
    CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);

    // Carry over the device id together with the resolved SM count.
    cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};

    return Params{
        CollectiveMainloop::to_underlying_arguments(args.mainloop),
        CollectiveEpilogue::to_underlying_arguments(args.epilogue),
        hw_info,
        TileScheduler::to_underlying_arguments(args.scheduler)
    };
}
// Launch grid for the given runtime parameters. The decision is delegated entirely to the tile
// scheduler, which receives its own params plus the SM count stored in hw_info.
static dim3
get_grid_shape(Params const& params) {
    auto const sm_count = params.hw_info.sm_count;
    return TileScheduler::get_grid_shape(params.scheduler, sm_count);
}
// Thread-block shape: one-dimensional, MaxThreadsPerBlock threads along x.
static dim3
get_block_shape() {
    dim3 block_shape(MaxThreadsPerBlock, 1, 1);
    return block_shape;
}
// Kernel body.
//
// Sets up the barriers and pipelines that live in shared memory, synchronizes the block (or the
// cluster), then splits the thread block by warp group:
//   * warp group 0 is the producer: it runs mainloop.load_kv_new (AppendKV only), mainloop.load
//     and finally mainloop.load_tail;
//   * all other warp groups are consumers: they run mainloop.store_kv_new (AppendKV only),
//     mainloop.mma / mainloop.mma_pv, and then the epilogue.
//
// Parameters:
//   params   - kernel parameters as produced by to_underlying_arguments().
//   smem_buf - buffer that is reinterpreted as SharedStorage.
//              NOTE(review): presumably dynamic shared memory of at least SharedStorageSize bytes
//              with suitable alignment - confirm against the launcher.
CUTLASS_DEVICE
void
operator()(Params const& params, char* smem_buf) {

    static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
    // Consumer code passes (threadIdx.x - MmaThreadOffset) as its thread index, i.e. an index that
    // starts at 0 for the first thread after the load warp group(s).
    static constexpr int MmaThreadOffset = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
    static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});

    using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK;
    using MainloopPipelineV = typename CollectiveMainloop::MainloopPipelineV;
    using MainloopPipelineVt = typename CollectiveMainloop::MainloopPipelineVt;
    using MainloopPipelineKVNew = typename CollectiveMainloop::MainloopPipelineKVNew;
    using PipelineState = typename CollectiveMainloop::PipelineState;
    using PipelineParamsK = typename MainloopPipelineK::Params;
    using PipelineParamsV = typename MainloopPipelineV::Params;
    using PipelineParamsVt = typename MainloopPipelineVt::Params;
    using PipelineParamsKVNew = typename MainloopPipelineKVNew::Params;

    SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);

    int const lane_predicate = cute::elect_one_sync();
    int const warp_idx = cutlass::canonical_warp_idx_sync();

    // Issue Tma Descriptor Prefetch from a single thread
    if (warp_idx == 0 && lane_predicate) {
        CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
        CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
    }

    // Obtain warp index
    int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
    int warp_group_idx = cutlass::canonical_warp_group_idx();

    // A single elected thread initializes the stand-alone barriers. The arrival count is 1 when
    // the corresponding TMA flag is set, and one arrival per participating thread otherwise.
    if (warp_idx == 0 && lane_predicate) {
        shared_storage.pipelines.barrier_Q.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);
        if constexpr (HasQv) {
            shared_storage.pipelines.barrier_Qv.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);
        }
        shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) * (Use_TMA_O ? 1 : NumMmaThreads) /*numThreads*/);
    }

    // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
    PipelineParamsK pipeline_params_k;
    pipeline_params_k.role = warp_group_idx == 0
        ? MainloopPipelineK::ThreadCategory::Producer
        : MainloopPipelineK::ThreadCategory::Consumer;
    // The params type has different fields depending on Use_TMA_KV, hence the two branches.
    // With LargeHeadDimV only one warp group's worth of threads is counted as K consumers.
    if constexpr (Use_TMA_KV) {
        pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
        pipeline_params_k.is_leader = warp_group_thread_idx == 0;
        pipeline_params_k.num_consumers = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup;
    } else {
        pipeline_params_k.consumer_arv_count = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup;
        pipeline_params_k.producer_arv_count = NumProducerThreads;
    }

    // pipeline_params_vt starts as a copy of the K params. It is used for pipeline_v when V is not
    // transposed, and (after its consumer count is overwritten further down) for pipeline_vt.
    static_assert(is_same_v<PipelineParamsK, PipelineParamsVt>);
    PipelineParamsVt pipeline_params_vt = pipeline_params_k;
    if constexpr (Use_TMA_KV && !SameHeadDim) {
        pipeline_params_vt.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;
        if constexpr (LargeHeadDimV) { pipeline_params_vt.num_consumers = NumMmaThreads; }
    } else {
        if constexpr (LargeHeadDimV) { pipeline_params_vt.consumer_arv_count = NumMmaThreads; }
    }

    // The pipeline constructor takes a ClusterShape argument only when Use_TMA_KV is set.
    MainloopPipelineK pipeline_k = [&] {
        if constexpr (Use_TMA_KV) {
            return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{});
        } else {
            return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k);
        }
    }();
    // MainloopPipelineV pipeline_v(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{});
    MainloopPipelineV pipeline_v = [&] {
        if constexpr (!Transpose_V) {
            static_assert(is_same_v<PipelineParamsK, PipelineParamsV>);
            if constexpr (Use_TMA_KV) {
                return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt, ClusterShape{});
            } else {
                return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt);
            }
        } else {
            // With Transpose_V, pipeline_v gets its own params: all producer threads arrive as
            // producers and all MMA threads as consumers.
            PipelineParamsV pipeline_params_v;
            pipeline_params_v.role = warp_group_idx == 0
                ? MainloopPipelineV::ThreadCategory::Producer
                : MainloopPipelineV::ThreadCategory::Consumer;
            pipeline_params_v.producer_arv_count = NumProducerThreads;
            pipeline_params_v.consumer_arv_count = NumMmaThreads;
            return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v);
        }
    }();
    // If we need to transpose V (e.g. FP8 and V is row-major), we use pipeline_vt for the TMA, then
    // the producer WG will read from pipeline_vt and write to pipeline_v.
    // If we don't need to transpose V, we use pipeline_v for the TMA, and pipeline_vt won't be used.
    // Technically for pipeline_params_vt, warp0 of WG0 is the producer and all of WG0 are consumers.
    // However, the thread role isn't used in the pipeline implementation.
    // Note: this lambda mutates pipeline_params_vt, so it must run after pipeline_v has been built.
    MainloopPipelineVt pipeline_vt = [&] {
        if constexpr (Use_TMA_KV) {
            pipeline_params_vt.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG
            return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt, ClusterShape{});
        } else {
            pipeline_params_vt.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG
            return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt);
        }
    }();

    // Pipelines for the AppendKV path. When AppendKV is false, conditional_return yields nullptr
    // in place of each pipeline object.
    PipelineParamsKVNew pipeline_params_kv_new;
    pipeline_params_kv_new.role = warp_group_idx == 0
        ? MainloopPipelineKVNew::ThreadCategory::Producer
        : MainloopPipelineKVNew::ThreadCategory::Consumer;
    pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
    pipeline_params_kv_new.is_leader = warp_group_thread_idx == 0;
    pipeline_params_kv_new.num_consumers = NumMmaThreads;
    auto pipeline_k_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_k_new, pipeline_params_kv_new, ClusterShape{}), nullptr);
    // V_new reuses the same params; only the transaction byte count differs when head dims differ.
    if constexpr (!SameHeadDim) {
        pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;
    }
    auto pipeline_v_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_v_new, pipeline_params_kv_new, ClusterShape{}), nullptr);

    CollectiveMainloop mainloop;
    CollectiveEpilogue epilogue;

    // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
    if constexpr (size(ClusterShape{}) > 1) {
        cute::cluster_arrive_relaxed();
        cute::cluster_wait();
    } else {
        __syncthreads();
    }

    TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));

    if (warp_group_idx == 0) {  // Producer
        cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();

        // The pipelines for AppendKV and main attention are different, since e.g. main attention
        // might use cp.async to load KV (if PagedKVNonTMA) while AppendKV always uses TMA to load
        // KV_new. Since the pipeline states are different, we have to manually sync to make
        // sure the two pipelines don't race when accessing smem_k and smem_v.
        PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipelineK>();
        PipelineState smem_pipe_write_new = cutlass::make_producer_start_state<MainloopPipelineKVNew>();
        int work_idx = 0;
        // Index of this warp within its warp group (0..3), broadcast from lane 0.
        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
        static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp;
        // When only one warp produces, the other three warps of the load warp group have nothing to do.
        if constexpr (SingleProducerWarp) {
            if (warp_idx_in_warpgroup != 0) { return; }
        }
        // With a multi-warp producer, warps other than warp 0 take part in scheduling as
        // non-producer warps (see the IsProducerWarp template argument below).
        if (!SingleProducerWarp && warp_idx_in_warpgroup != 0) { scheduler.init_consumer(); }

        // NOTE(review): presumably the programmatic-dependent-launch wait on upstream grids;
        // see cutlass/arch/grid_dependency_control.h.
        cutlass::arch::wait_on_dependent_grids();

        // Load Q, K, V
        for (auto work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
             work_tile_info.is_valid(params.scheduler);
             work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {

            auto block_coord = work_tile_info.get_block_coord(params.scheduler);
            // The third argument is the K extent: shape_K's first mode, multiplied by the page-table's
            // second mode when a page table is present.
            SeqlenInfo_t seqlen_info{
                get<2>(block_coord) /*bidb*/,
                get<0>(params.mainloop.shape_Q),
                !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
                get<0>(params.mainloop.shape_K_new),
                params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
                params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
                params.mainloop.seqlens_rotary
            };
            if constexpr (AppendKV) {
                bool tile_new_valid = mainloop.load_kv_new(
                    params.mainloop, pipeline_k_new, pipeline_v_new,
                    smem_pipe_write_new, shared_storage, seqlen_info, block_coord, work_idx);
                if (tile_new_valid) {
                    // Pairs with the consumers' NamedBarrier::arrive on the same AppendKV barrier id.
                    // if (threadIdx.x == 0) { printf("Producer: Before sync\n"); }
                    cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::AppendKV) /*id*/);
                    // if (threadIdx.x == 0) { printf("Producer: After sync\n"); }
                }
            }
            // Callback handed to mainloop.load so it can trigger the scheduler's prefetch of the next tile.
            auto scheduler_prefetch = [&scheduler, &params, &work_tile_info]() {
                scheduler.prefetch_next_work(params.scheduler, work_tile_info);
            };
            // pipeline_vt won't be used if we don't need to transpose V.
            mainloop.load(params.mainloop, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write,
                          shared_storage, scheduler_prefetch, seqlen_info, block_coord, work_idx);
        }
        mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, work_idx);
    } else {  // Consumer
        cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();

        // Initialize matmul objects.
        TiledMmaPV tiled_mma_pv;

        PipelineState smem_pipe_read;
        PipelineState smem_pipe_read_new;
        // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
        // (like in Cutlass's gemm) because the read and release pipeline states are always the same.

        scheduler.init_consumer();
        mainloop.mma_init();

        int work_idx = 0;
        CUTLASS_PRAGMA_NO_UNROLL
        for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
             work_tile_info.is_valid(params.scheduler);
             // get_next_work will be called before the epilogue
             ) {
            auto block_coord = work_tile_info.get_block_coord(params.scheduler);
            int const bidb = get<2>(block_coord);
            // Same construction as in the producer branch.
            SeqlenInfo_t seqlen_info{
                bidb,
                get<0>(params.mainloop.shape_Q),
                !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
                get<0>(params.mainloop.shape_K_new),
                params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
                params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
                params.mainloop.seqlens_rotary
            };
            if constexpr (AppendKV) {
                bool tile_new_valid = mainloop.store_kv_new(
                    params.mainloop, pipeline_k_new, pipeline_v_new, smem_pipe_read_new,
                    threadIdx.x - MmaThreadOffset, shared_storage, seqlen_info, block_coord);
                if (tile_new_valid) {
                    // if (threadIdx.x == 128) { printf("Consumer: Before sync\n"); }
                    // We need this sync so that the gmem write from the consumers is visible to the producer
                    // that might do TMA read after that.
                    asm volatile ("fence.proxy.async.global;");
                    cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::AppendKV) /*id*/);
                    // arrive is enough, we don't need sync. The producer will sync, which means
                    // after that sync we're guaranteed that the AppendKV pipeline has finished
                    // loading and consuming smem_k and smem_v.
                    // if (threadIdx.x == 128) { printf("Consumer: After sync\n"); }
                }
            }
            // If there's tanh softcap, the scaling will be done before tanh.
            float softmax_scale_log2 = params.mainloop.softmax_scale_log2;
            // FP8 without softcap: fold the per-(batch, kv-head) Q and K descale factors into the
            // softmax scale. A null descale pointer means a factor of 1.
            if constexpr (Is_FP8 && !Has_softcap) {
                int const bidh = get<1>(block_coord);
                int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;
                float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];
                float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];
                softmax_scale_log2 *= q_descale * k_descale;
            }
            flash::Softmax<!LargeHeadDimV ? 2 * (2 * kBlockM / NumMmaThreads) : 2, /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2);
            // Attention output (GEMM-II) accumulator.
            Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{}));
            // Assigned on every path below; decides between epilogue.store and epilogue.store_zero.
            bool tile_valid;
            if constexpr (!LargeHeadDimV) {
                tile_valid = mainloop.mma(
                    params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,
                    tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage);
            } else {  // mma_pv might not compile if !LargeHeadDimV
                // With LargeHeadDimV, warp group 1 runs the full mma while the other consumer
                // warp group(s) run mma_pv, which takes only pipeline_v.
                if (warp_group_idx == 1) {
                    tile_valid = mainloop.mma(
                        params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,
                        tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage);
                } else {
                    tile_valid = mainloop.mma_pv(
                        params.mainloop, pipeline_v, smem_pipe_read,
                        tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage);
                }
            }
            // Do this here before the epilogue so that the next tile is ready to go.
            // block_coord was captured above, so the epilogue below still refers to the current tile.
            work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info);
            if constexpr (Split && Varlen) {
                if (!work_tile_info.is_valid(params.scheduler)) {  // Last tile
                    cutlass::arch::launch_dependent_grids();
                }
            }
            if (tile_valid) {
                // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); }
                epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv,
                               threadIdx.x - MmaThreadOffset, block_coord);
            } else {
                // Write 0 to gO and -inf to gLSE.
                epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord);
            }
        }
        epilogue.store_tail();
    }

}
};
} // namespace flash
|