File size: 2,874 Bytes
eb8ddce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include "cutlass/arch/barrier.h"
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
// cutlass::arch::NamedBarrier::sync/arrive are only enabled for Sm90, even though they
// work on Sm80 as well. We reimplement them here, enabled for both Sm90 and Sm80.
// Arrive at a user named barrier and wait on it (PTX `bar.sync`).
//
// num_threads : thread count handed to `bar.sync` as its second operand.
// barrier_id_ : zero-based user barrier index. It is shifted by
//               cutlass::arch::ReservedNamedBarriers::FirstUserBarrier so that
//               it does not overlap the ids CUTLASS reserves for itself.
CUTLASS_DEVICE
static void named_barrier_sync(uint32_t num_threads, uint32_t barrier_id_) {
    // First barrier id that CUTLASS leaves available to user code.
    constexpr uint32_t kFirstUserBarrier =
        static_cast<uint32_t>(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier);
    const uint32_t hw_barrier_id = kFirstUserBarrier + barrier_id_;

    asm volatile("bar.sync %0, %1;" : : "r"(hw_barrier_id), "r"(num_threads));
    // The synclog record is emitted after the barrier instruction has been issued.
    cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, hw_barrier_id);
}
// Arrive at one of CUTLASS's reserved named barriers and wait on it (PTX `bar.sync`).
//
// num_threads            : thread count handed to `bar.sync` as its second operand.
// reserved_named_barriers: reserved barrier to use; its enumerator value is used
//                          directly as the barrier id (no user-barrier offset is added).
CUTLASS_DEVICE
static void named_barrier_sync(uint32_t num_threads, cutlass::arch::ReservedNamedBarriers reserved_named_barriers) {
    const auto hw_barrier_id = static_cast<uint32_t>(reserved_named_barriers);

    asm volatile("bar.sync %0, %1;" : : "r"(hw_barrier_id), "r"(num_threads));
    // The synclog record is emitted after the barrier instruction has been issued.
    cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, hw_barrier_id);
}
// Signal arrival at a user named barrier without waiting on it (PTX `bar.arrive`).
//
// num_threads : thread count handed to `bar.arrive` as its second operand.
// barrier_id_ : zero-based user barrier index. It is shifted by
//               cutlass::arch::ReservedNamedBarriers::FirstUserBarrier so that
//               it does not overlap the ids CUTLASS reserves for itself.
CUTLASS_DEVICE
static void named_barrier_arrive(uint32_t num_threads, uint32_t barrier_id_) {
    // First barrier id that CUTLASS leaves available to user code.
    constexpr uint32_t kFirstUserBarrier =
        static_cast<uint32_t>(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier);
    const uint32_t hw_barrier_id = kFirstUserBarrier + barrier_id_;

    // Unlike the sync variant, the synclog record is emitted before the instruction.
    cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, hw_barrier_id);
    asm volatile("bar.arrive %0, %1;" : : "r"(hw_barrier_id), "r"(num_threads));
}
// Signal arrival at one of CUTLASS's reserved named barriers without waiting
// on it (PTX `bar.arrive`).
//
// num_threads            : thread count handed to `bar.arrive` as its second operand.
// reserved_named_barriers: reserved barrier to use; its enumerator value is used
//                          directly as the barrier id (no user-barrier offset is added).
CUTLASS_DEVICE
static void named_barrier_arrive(uint32_t num_threads, cutlass::arch::ReservedNamedBarriers reserved_named_barriers) {
    const auto hw_barrier_id = static_cast<uint32_t>(reserved_named_barriers);

    // Unlike the sync variant, the synclog record is emitted before the instruction.
    cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, hw_barrier_id);
    asm volatile("bar.arrive %0, %1;" : : "r"(hw_barrier_id), "r"(num_threads));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Enumerates the reserved named barriers to avoid potential conflicts
// Named-barrier indices, one enumerator per distinct barrier, numbered 0..7.
// NOTE(review): judging by the name these belong to the forward pass, and the
// values are presumably cast to uint32_t and passed as the user barrier index
// of named_barrier_sync / named_barrier_arrive -- confirm against the callers.
// Values are spelled out explicitly so that reordering the list cannot
// silently change which barrier id an enumerator maps to.
enum class FwdNamedBarriers {
    QueryEmpty       = 0,
    WarpSchedulerWG1 = 1,
    WarpSchedulerWG2 = 2,
    WarpSchedulerWG3 = 3,
    AppendKV         = 4,
    QueryRotated     = 5,
    PFull            = 6,
    PEmpty           = 7,
};
// Named-barrier indices, one enumerator per distinct barrier, numbered 0..7.
// NOTE(review): judging by the name these belong to the backward pass, and the
// values are presumably cast to uint32_t and passed as the user barrier index
// of named_barrier_sync / named_barrier_arrive -- confirm against the callers.
// Values are spelled out explicitly so that reordering the list cannot
// silently change which barrier id an enumerator maps to.
enum class BwdNamedBarriers {
    KVEmpty    = 0,
    PdS        = 1,
    dQEmptyWG1 = 2,
    dQEmptyWG2 = 3,
    dQEmptyWG3 = 4,
    dQFullWG1  = 5,
    dQFullWG2  = 6,
    dQFullWG3  = 7,
};
} // flash
|