File size: 8,220 Bytes
eb8ddce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
namespace flash {
// Static helpers that compute which tile indices a work unit has to visit.
//
// Template parameters:
//   SeqlenInfo_t - type providing the fields seqlen_q, seqlen_k, seqlen_k_og and seqlen_k_new.
//   kBlockM      - tile extent along the query dimension (tile counts are ceil_div(seqlen_q, kBlockM)).
//   kBlockN      - tile extent along the key dimension (tile counts are ceil_div(seqlen_k, kBlockN)).
//   Is_causal    - restrict the key range on the right using the query row index.
//   Is_local     - additionally apply window_size_left / window_size_right and, when its divisor is > 0,
//                  attention_chunk_divmod.
//   PackGQA      - row indices derived from m_block are divided by qhead_per_khead_divmod before use.
//   Split        - the key-tile range is partitioned and only the slice selected by split_idx is returned.
//
// All returned ranges are half-open [min, max); a range with max <= min is empty.
template <class SeqlenInfo_t, int kBlockM, int kBlockN, bool Is_causal, bool Is_local, bool PackGQA=false, bool Split=false>
struct BlockMN {

private:

    // Index of the first row covered by tile m_block, divided by qhead_per_khead when PackGQA is set.
    static
    CUTLASS_DEVICE
    int get_m_idx_min(int const m_block, cutlass::FastDivmod const& qhead_per_khead_divmod) {
        int m_idx_min = m_block * kBlockM;
        if constexpr (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); }
        return m_idx_min;
    }

    // One past the index of the last row covered by tile m_block. With PackGQA the last row itself is
    // divided by qhead_per_khead and then incremented, so the bound stays exclusive.
    static
    CUTLASS_DEVICE
    int get_m_idx_max(int const m_block, cutlass::FastDivmod const& qhead_per_khead_divmod) {
        int m_idx_max = (m_block + 1) * kBlockM;
        // TODO: check off-by-1 error
        if constexpr (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1; }
        return m_idx_max;
    }

    // Right key bound for the diagonal position n_idx. Without Is_local this is n_idx itself; with
    // Is_local it is widened by window_size_right and, if attention chunking is enabled
    // (divisor > 0), capped at n_idx rounded up to the chunk size.
    static
    CUTLASS_DEVICE
    int get_n_idx_right(int const n_idx, int const window_size_right,
                        cutlass::FastDivmod const& attention_chunk_divmod) {
        int n_idx_right = n_idx;
        if constexpr (Is_local) {
            n_idx_right = n_idx + window_size_right;
            if (attention_chunk_divmod.divisor > 0) {
                n_idx_right = std::min(n_idx_right, flash::round_up(attention_chunk_divmod, n_idx));
            }
        }
        return n_idx_right;
    }

    // Left key bound for the diagonal position n_idx. Without Is_local this is n_idx itself; with
    // Is_local it is moved left by window_size_left and, if attention chunking is enabled
    // (divisor > 0), raised to at least n_idx rounded down to the chunk size.
    static
    CUTLASS_DEVICE
    int get_n_idx_left(int const n_idx, int const window_size_left,
                       cutlass::FastDivmod const& attention_chunk_divmod) {
        int n_idx_left = n_idx;
        if constexpr (Is_local) {
            n_idx_left = n_idx - window_size_left;
            if (attention_chunk_divmod.divisor > 0) {
                n_idx_left = std::max(n_idx_left, flash::round_down(attention_chunk_divmod, n_idx));
            }
        }
        return n_idx_left;
    }

public:

    // Returns the half-open range [n_block_min, n_block_max) of key tiles that query tile m_block visits.
    //
    //   seqlen_info             - supplies seqlen_q and seqlen_k.
    //   m_block                 - query tile index.
    //   bidb                    - not used by this function.
    //   split_idx               - only read when Split is set: the low 16 bits are the split index, the
    //                             high 16 bits, if non-zero, override num_splits.
    //   num_splits              - number of splits used when split_idx carries no override.
    //   window_size_left/right  - local-attention window sizes; only read when Is_local is set.
    //   attention_chunk_divmod  - chunk size; only applied when Is_local is set and its divisor is > 0.
    //   qhead_per_khead_divmod  - divisor applied to row indices when PackGQA is set.
    //
    // The result may be empty (n_block_max <= n_block_min); callers must handle that case.
    static
    CUTLASS_DEVICE
    cute::tuple<int, int> get_n_block_min_max(
            SeqlenInfo_t const& seqlen_info,
            int const m_block, int const bidb, int const split_idx, int const num_splits,
            int const window_size_left, int const window_size_right,
            cutlass::FastDivmod const& attention_chunk_divmod,
            cutlass::FastDivmod const& qhead_per_khead_divmod) {

        int const seqlen_k = seqlen_info.seqlen_k;
        int const seqlen_q = seqlen_info.seqlen_q;

        // Upper bound: start from every key tile, then clip to the right mask boundary of the last row.
        int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
        if constexpr (Is_causal || Is_local) {
            int const m_idx_max = get_m_idx_max(m_block, qhead_per_khead_divmod);
            int const n_idx = m_idx_max + seqlen_k - seqlen_q;
            int const n_idx_right = get_n_idx_right(n_idx, window_size_right, attention_chunk_divmod);
            n_block_max = std::min(n_block_max, cute::ceil_div(n_idx_right, kBlockN));
        }

        // Lower bound: only a local window can exclude leading key tiles.
        int n_block_min = 0;
        if constexpr (Is_local) {
            int const m_idx_min = get_m_idx_min(m_block, qhead_per_khead_divmod);
            int const n_idx = m_idx_min + seqlen_k - seqlen_q;
            int const n_idx_left = get_n_idx_left(n_idx, window_size_left, attention_chunk_divmod);
            // Integer division truncates toward zero, so a negative bound is clamped to tile 0 here.
            n_block_min = std::max(int(0), n_idx_left / kBlockN);
        }
        // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }

        if constexpr (Split) {
            // split_idx packs two values: the upper 16 bits are for num_splits, the lower 16 bits are
            // the actual split index. The unsigned conversion keeps the shift free of sign extension.
            uint32_t const num_splits_dynamic_u = static_cast<uint32_t>(split_idx) >> 16;
            int const num_splits_dynamic = static_cast<int>(num_splits_dynamic_u);
            int const split_idx_actual = split_idx & 0x0000FFFF;
            int const num_splits_actual = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
            // An empty range yields zero tiles per split, so every split stays empty.
            int const num_n_blocks_per_split = n_block_max <= n_block_min
                ? 0
                : cute::ceil_div(n_block_max - n_block_min, num_splits_actual);
            n_block_min = n_block_min + split_idx_actual * num_n_blocks_per_split;
            // Trailing splits may receive fewer tiles, or none, because of the ceil above.
            n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max);
            // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, num_splits_dynamic = %d, num_splits_actual = %d, num_n_blocks_per_split = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, num_splits_dynamic, num_splits_actual, num_n_blocks_per_split, n_block_min, n_block_max); }
        }
        // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
        return {n_block_min, n_block_max};
    }

    // Maps the key-tile range of get_n_block_min_max onto tile indices counted from position
    // seqlen_info.seqlen_k_og, clamped to [0, seqlen_info.seqlen_k_new).
    // NOTE(review): presumably seqlen_k_og is the key length before new keys are appended and
    // seqlen_k_new the number of appended keys - confirm against SeqlenInfo_t.
    //
    // Takes the same arguments as get_n_block_min_max and returns [n_block_new_min, n_block_new_max);
    // the range is empty (max == min) when no position of the original range falls past seqlen_k_og.
    static
    CUTLASS_DEVICE
    cute::tuple<int, int> get_n_block_k_new_min_max(
            SeqlenInfo_t const& seqlen_info,
            int const m_block, int const bidb, int const split_idx, int const num_splits,
            int const window_size_left, int const window_size_right,
            cutlass::FastDivmod const& attention_chunk_divmod,
            cutlass::FastDivmod const& qhead_per_khead_divmod) {

        auto [n_block_min, n_block_max] = get_n_block_min_max(
            seqlen_info, m_block, bidb, split_idx, num_splits,
            window_size_left, window_size_right, attention_chunk_divmod, qhead_per_khead_divmod);
        // Shift the element range so that position seqlen_k_og becomes 0, then clamp it.
        int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0);
        int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new);
        int const n_block_new_min = idx_k_new_min / kBlockN;
        int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min;
        // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_info.seqlen_k_new, seqlen_info.seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);}
        return {n_block_new_min, n_block_new_max};
    }

    // Returns the half-open range [m_block_min, m_block_max) of query tiles that key tile n_block
    // contributes to.
    //
    //   seqlen_info             - supplies seqlen_q and seqlen_k.
    //   n_block                 - key tile index.
    //   bidb                    - not used by this function.
    //   window_size_left/right  - window sizes; window_size_left is only read when Is_local is set,
    //                             window_size_right whenever Is_causal or Is_local is set.
    //   sink_token_length       - key tiles with n_block < ceil_div(sink_token_length, kBlockN) are
    //                             exempt from the left-window cap on m_block_max.
    //
    // PackGQA and attention chunking are not taken into account here.
    static
    CUTLASS_DEVICE
    cute::tuple<int, int> get_m_block_min_max(
            SeqlenInfo_t const& seqlen_info,
            int const n_block, int const bidb,
            int const window_size_left, int const window_size_right, int const sink_token_length) {
        // TODO: support attention_chunk
        int const seqlen_q = seqlen_info.seqlen_q;
        int const seqlen_k = seqlen_info.seqlen_k;

        int m_block_max = cute::ceil_div(seqlen_q, kBlockM);
        if constexpr (Is_local) {
            if (n_block >= cute::ceil_div(sink_token_length, kBlockN)) {
                int const m_idx_end = (n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_size_left;
                m_block_max = std::min(m_block_max, cute::ceil_div(m_idx_end, kBlockM));
            }
        }

        int m_block_min = 0;
        if constexpr (Is_causal || Is_local) {
            int const m_idx_begin = n_block * kBlockN + seqlen_q - seqlen_k - window_size_right;
            // Truncating division plus the max with 0 clamps a negative start to tile 0.
            m_block_min = std::max(m_block_min, m_idx_begin / kBlockM);
        }
        return {m_block_min, m_block_max};
    }

    // If we have separate iterations with causal or local masking at the start, where do we stop.
    // Returns the key tile containing the right mask boundary of the first row of m_block, but never
    // less than n_block_min.
    static
    CUTLASS_DEVICE
    int get_n_block_min_causal_local_mask(
            SeqlenInfo_t const& seqlen_info,
            int const m_block, int const n_block_min, int const window_size_right,
            cutlass::FastDivmod const& attention_chunk_divmod,
            cutlass::FastDivmod const& qhead_per_khead_divmod) {
        int const m_idx_min = get_m_idx_min(m_block, qhead_per_khead_divmod);
        int const n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q;
        int const n_idx_right = get_n_idx_right(n_idx, window_size_right, attention_chunk_divmod);
        return std::max(n_block_min, n_idx_right / kBlockN);
    }

    // If we have separate iterations with local masking at the end, where do we stop the non-masked
    // iterations. Without Is_local this is n_block_min unchanged; otherwise it is the first key tile at
    // or after the left mask boundary of the last row of m_block, but never less than n_block_min.
    static
    CUTLASS_DEVICE
    int get_n_block_min_before_local_mask(
            SeqlenInfo_t const& seqlen_info,
            int const m_block, int const n_block_min, int const window_size_left,
            cutlass::FastDivmod const& attention_chunk_divmod,
            cutlass::FastDivmod const& qhead_per_khead_divmod) {
        int const m_idx_max = get_m_idx_max(m_block, qhead_per_khead_divmod);
        int const n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q;
        int const n_idx_left = get_n_idx_left(n_idx, window_size_left, attention_chunk_divmod);
        return !Is_local ? n_block_min : std::max(n_block_min, cute::ceil_div(n_idx_left, kBlockN));
    }

};
} // namespace flash
|