File size: 29,169 Bytes
eb8ddce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 |
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include <cutlass/cutlass.h>
#include <cutlass/fast_math.h> // For FastDivMod
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/builders/sm90_common.inl"
#include "cutlass/epilogue/collective/builders/sm90_common.inl"
#include "seqlen.h"
#include "named_barrier.hpp"
#include "pack_gqa.h"
#include "utils.h"
namespace flash {
using namespace cute;
template <class TileShape_MNK_PV_, class ClusterShape_, class Element_, class ArchTag_,
int NumEpilogueThreads_, bool Varlen_, bool PackGQA_, bool Split_, bool FP8PermuteCol=false>
struct CollectiveEpilogueFwd {
using TileShape_MNK_PV = TileShape_MNK_PV_;
using ClusterShape = ClusterShape_;
using Element = Element_;
using ElementPartial = float;
using ArchTag = ArchTag_;
static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
static constexpr bool Varlen = Varlen_;
static constexpr bool PackGQA = PackGQA_;
static constexpr bool Split = Split_;
static constexpr bool Use_smem = !(Split && !Varlen);
static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA;
static_assert(ArchTag::kMinComputeCapability >= 80);
static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1);
static_assert(sizeof(Element) <= 2);
static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{});
static constexpr bool LargeHeadDimV = kHeadDimV > 256;
using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
// These are for storing the output tensor without TMA (e.g., for setting output to zero)
static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDimV % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore");
// We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements
// in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times
// we need to call divmod.
static constexpr int kBytePerRow = kHeadDimV * sizeof(Element);
static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore;
// If PackGQA, we split the work of compute O_ptr among threads in the same row, so we need this to within a warp
static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0);
static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumEpilogueThreads / kGmemThreadsPerRow");
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, Int<kGmemElemsPerStore>>>{})); // Val layout, 8 or 16 vals per store
using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK_PV{})), decltype(cute::get<1>(TileShape_MNK_PV{}))>());
using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 1>(TileShape_MNK_PV{})));
static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1));
static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4);
using SmemLayoutAtomO = decltype(
composition(Swizzle<kSwizzle, kSwizzleBase, kSwizzleBase>{},
Layout<Shape<_8, Int<kBlockKGmem>>,
Stride<Int<kBlockKGmem>, _1>>{}));
using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_MNK_PV{})));
using SmemLayoutO = std::conditional_t<ArchTag::kMinComputeCapability >= 90, SmemLayoutOTMA, SmemLayoutOSTS>;
using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch, num_splits)
using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
using StrideLSE = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen_q, head, batch, num_splits)
// ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits)
using ShapeOPacked = std::conditional_t<!PackGQA, ShapeO, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t, int32_t>>;
using StrideOPacked = std::conditional_t<!PackGQA, StrideO, cute::Stride<cute::Stride<int64_t, int64_t>, _1, int64_t, int64_t, int64_t>>;
// ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits)
using ShapeLSEPacked = std::conditional_t<!PackGQA, cute::Shape<int32_t, int32_t, int32_t, int32_t>, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;
using StrideLSEPacked = std::conditional_t<!PackGQA, StrideLSE, cute::Stride<cute::Stride<int64_t, _1>, int64_t, int64_t, int64_t>>;
using CopyOpR2S = std::conditional_t<
ArchTag::kMinComputeCapability >= 90,
// cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16)
decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator<StrideO, Element>()),
AutoVectorizingCopyWithAssumedAlignment<128>
>;
using SmemCopyAtomO = Copy_Atom<CopyOpR2S, Element>;
// static constexpr size_t SmemAlignmentO = cutlass::detail::alignment_for_swizzle(SmemLayoutO{});
// static_assert(SmemAlignmentO >= 128, "Require at least 128B alignment");
// struct TensorStorage : cute::aligned_struct<SmemAlignmentO> {
// cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0, SmemAlignmentO> smem_o;
// };
struct TensorStorage : cute::aligned_struct<128> {
cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0> smem_o;
};
using TMA_O = std::conditional_t<
Use_TMA_O,
decltype(make_tma_copy(
GmemTiledCopyOTMA{},
make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapeO{}, StrideO{}),
SmemLayoutOTMA{},
select<0, 1>(TileShape_MNK_PV{}),
_1{})), // no mcast for O
std::nullptr_t
>;
// Host side kernel arguments
struct Arguments {
Element* ptr_O;
ShapeO const shape_O;
StrideO const stride_O;
ElementPartial* ptr_O_partial;
StrideO const stride_O_partial;
float* ptr_LSE;
StrideLSE const stride_LSE;
float* ptr_LSE_partial;
StrideLSE const stride_LSE_partial;
int32_t const nheads_kv;
int const* cu_seqlens = nullptr;
int const* seqused = nullptr;
};
// Device side kernel params
struct Params {
Element* ptr_O;
ShapeO const shape_O;
StrideO const stride_O;
ShapeOPacked const shape_O_packed;
StrideOPacked const stride_O_packed;
ElementPartial* ptr_O_partial;
StrideO const stride_O_partial;
StrideOPacked const stride_O_partial_packed;
float* ptr_LSE;
StrideLSE const stride_LSE;
ShapeLSEPacked const shape_LSE_packed;
StrideLSEPacked const stride_LSE_packed;
float* ptr_LSE_partial;
StrideLSE const stride_LSE_partial;
StrideLSEPacked const stride_LSE_partial_packed;
cutlass::FastDivmod qhead_per_khead_divmod;
TMA_O tma_store_O;
int const* cu_seqlens = nullptr;
int const* seqused = nullptr;
};
static Params
to_underlying_arguments(Arguments const& args) {
Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);
TMA_O tma_store_O = [&]{
if constexpr (Use_TMA_O) {
return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast
} else {
return nullptr;
}
}();
// If PackGQA, reshape O to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size, num_splits)
int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_O), args.nheads_kv);
auto const shape_O_packed = cute::conditional_return<!PackGQA>(
args.shape_O,
make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
);
auto const stride_O_packed = cute::conditional_return<!PackGQA>(
args.stride_O,
make_stride(make_stride(get<2>(args.stride_O), get<0>(args.stride_O)), get<1>(args.stride_O), get<2>(args.stride_O) * qhead_per_khead, get<3>(args.stride_O), get<4>(args.stride_O))
);
auto const stride_O_partial_packed = cute::conditional_return<!PackGQA>(
args.stride_O_partial,
make_stride(make_stride(get<2>(args.stride_O_partial), get<0>(args.stride_O_partial)), get<1>(args.stride_O_partial), get<2>(args.stride_O_partial) * qhead_per_khead, get<3>(args.stride_O_partial), get<4>(args.stride_O_partial))
);
// If PackGQA, Reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits)
auto const shape_LSE_packed = cute::conditional_return<!PackGQA>(
select<0, 2, 3, 4>(args.shape_O),
make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
);
auto const stride_LSE_packed = cute::conditional_return<!PackGQA>(
args.stride_LSE,
make_stride(make_stride(get<1>(args.stride_LSE), get<0>(args.stride_LSE)), get<1>(args.stride_LSE) * qhead_per_khead, get<2>(args.stride_LSE), get<3>(args.stride_LSE))
);
auto const stride_LSE_partial_packed = cute::conditional_return<!PackGQA>(
args.stride_LSE_partial,
make_stride(make_stride(get<1>(args.stride_LSE_partial), get<0>(args.stride_LSE_partial)), get<1>(args.stride_LSE_partial) * qhead_per_khead, get<2>(args.stride_LSE_partial), get<3>(args.stride_LSE_partial))
);
return {args.ptr_O, args.shape_O, args.stride_O, shape_O_packed, stride_O_packed,
args.ptr_O_partial, args.stride_O_partial, stride_O_partial_packed,
args.ptr_LSE, args.stride_LSE, shape_LSE_packed, stride_LSE_packed,
args.ptr_LSE_partial, args.stride_LSE_partial, stride_LSE_partial_packed,
cutlass::FastDivmod(qhead_per_khead),
tma_store_O, args.cu_seqlens, args.seqused};
}
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
if constexpr (Use_TMA_O) {
cute::prefetch_tma_descriptor(params.tma_store_O.get_tma_descriptor());
}
}
template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
CUTLASS_DEVICE void
store(Params const& params,
FrgTensorO& tOrO,
FrgTensorLSE const& lse,
SharedStorage& shared_storage,
TiledMma tiled_mma,
int thread_idx,
cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
) {
auto [m_block, bidh, bidb, split_idx] = block_coord;
int num_splits = get<4>(params.shape_O_packed);
if constexpr (Split && Varlen) {
uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx
}
bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1);
Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{});
// Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO);
static constexpr bool NeedFP8Permute = FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4);
// If we will possibly need tOrO in FP32, we'd want to permute tOrO before type conversion.
// Otherwise we can permute after conversion.
if constexpr (NeedFP8Permute && Split) { flash::permute_output_fp8_Vcolmajor(tOrO); }
Tensor tOrO_out = make_tensor_like<Element>(tOrO);
flash::convert_type_out(tOrO, tOrO_out);
if constexpr (NeedFP8Permute && !Split) { flash::permute_output_fp8_Vcolmajor(tOrO_out); }
// Make sure all WGs have finished reading V
// Technically we don't need this if we're not using smem, but the mainloop makes the assumption that
// all epilogue threads sync at least once during the epilogue (so that we can start loading Q with
// cp.async if we need).
flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
// Step 1: Write O from rmem -> smem
if constexpr (Use_smem) {
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// Tensor taccOsO = smem_thr_copy_O.partition_D(sO_pi); // ((Atom,AtomNum),PIPE_M,PIPE_N)
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
if constexpr (Use_TMA_O) {
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
} else {
flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
}
} else {
if constexpr (ArchTag::kMinComputeCapability >= 90) {
#pragma unroll
for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
shared_storage.pipelines.barrier_O.arrive(cta_id);
}
}
}
flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
bool is_varlen = Varlen && params.cu_seqlens;
int offset_o = seqlen_info.offset;
int seqlen_o = seqlen_info.seqlen;
int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);
// Step 2: Write LSE from rmem -> gmem
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
// (MMA,MMA_M,MMA_K)
Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()));
Tensor taccOcO_row = taccOcO_rowcol(_, _0{});
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
using PackGQA_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>;
using PackGQApartial_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, ElementPartial>;
Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),
params.shape_LSE_packed,
!is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);
// if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); }
if (!LargeHeadDimV || warp_group_idx == 0) {
if constexpr (!PackGQA) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
int const row = m_block * kBlockM + get<0>(taccOcO_row(mi));
if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); }
}
} else {
PackGQA_t::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
}
}
// Step 3: Write O from smem -> gmem
if constexpr (Use_TMA_O) {
Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx);
Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
auto block_tma_O = params.tma_store_O.get_slice(_0{});
Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K)
Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K)
int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
if (cute::elect_one_sync()) {
cute::copy(params.tma_store_O, tOsO, tOgO);
tma_store_arrive();
tma_store_wait<0>();
#pragma unroll
for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
shared_storage.pipelines.barrier_O.arrive(cta_id);
}
}
}
} else { // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence
if (!is_split) {
Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{});
Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
// if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast<int>(&mO(0)) - reinterpret_cast<int>(params.ptr_O)); }
GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
// Tensor tOsO = gmem_thr_copy_O.partition_S(sO_pi); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tOrO = make_fragment_like(tOsO);
cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
if constexpr (ArchTag::kMinComputeCapability >= 90) {
cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_v
#pragma unroll
for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
shared_storage.pipelines.barrier_O.arrive(cta_id);
}
}
if constexpr (!PackGQA) {
// (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOsO)));
#pragma unroll
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
);
} else {
// If PackGQA, we split the work of compute O_ptr among threads in the same row
PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
}
} else {
Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset_o * get<0>(params.stride_O_partial)), params.shape_O_packed, params.stride_O_partial_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx);
Tensor gOpartial = local_tile(mOpartial, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
// We already arrived on barrier_O earlier if !Use_smem
if constexpr (Use_smem) {
if constexpr (ArchTag::kMinComputeCapability >= 90) {
#pragma unroll
for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
shared_storage.pipelines.barrier_O.arrive(cta_id);
}
}
}
if constexpr (!PackGQA) {
static constexpr int kGmemElemsPerStoreDirect = 2;
cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial> gmem_copy_direct;
// Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout()));
Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
Tensor tOgO = thread_mma.partition_C(gOpartial);
Tensor tOgO_rowcol = make_tensor(tOgO.data(), flash::convert_layout_acc_rowcol(tOgO.layout()));
Tensor tOgO_copy = cute::tiled_divide(tOgO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
Tensor taccOcO_col = taccOcO_rowcol(_0{}, _);
#pragma unroll
for (int m = 0; m < size(taccOcO_row); ++m) {
if (get<0>(taccOcO_row(m)) < seqlen_o - m_block * kBlockM) {
#pragma unroll
for (int k = 0; k < size(taccOcO_col) / kGmemElemsPerStoreDirect; ++k) {
if (get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)) < get<1>(params.shape_O)) {
cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), tOgO_copy(_, m, k));
}
}
}
}
} else {
PackGQApartial_t::store_O_direct(mOpartial, tOrO, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
}
}
}
}
CUTLASS_DEVICE void
store_tail() {
// Don't need to do tma_store_wait<0>() here since we already did in @store
}
// Write 0 to output and -inf to LSE
CUTLASS_DEVICE void
store_zero(
Params const& params,
int thread_idx,
cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
) {
static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
auto [m_block, bidh, bidb, split_idx] = block_coord;
int num_splits = get<4>(params.shape_O_packed);
if constexpr (Split && Varlen) {
uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx
}
bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1);
flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
bool const is_varlen = Varlen && params.cu_seqlens;
int offset_o = seqlen_info.offset;
int seqlen_o = seqlen_info.seqlen;
int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor;
Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),
params.shape_LSE_packed,
!is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);
Tensor gLSE = local_tile(mLSE, Shape<Int<kBlockM>>{}, make_coord(m_block));
static_assert(kBlockM <= NumEpilogueThreads);
if (thread_idx < kBlockM) {
const int row = m_block * kBlockM + thread_idx;
if constexpr (!PackGQA) {
if (row < seqlen_o) { mLSE(row) = -INFINITY; }
} else {
if (row < seqlen_o * qhead_per_khead) {
int m_idx, h_idx;
m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row);
// mLSE has shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord"
mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY;
}
}
}
// If split, we don't have to write 0 to mOpartial if the mha_combine kernel is used,
// since it will not use the value of O if LSE is -inf.
if (!is_split) {
Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{});
GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
if constexpr (!PackGQA) {
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
#pragma unroll
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
Tensor tOrO = make_fragment_like(tOgO);
cute::clear(tOrO);
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
);
} else {
// If PackGQA, we split the work of compute O_ptr among threads in the same row
using PackGQA_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>;
Tensor tOrO = make_tensor<Element>(make_shape(Shape<_1, Int<kGmemElemsPerStore>>{}, size<1>(tOcO), size<2>(tOcO)));
cute::clear(tOrO);
PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
}
}
}
};
} // namespace flash
|