File size: 11,294 Bytes
eb8ddce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
#pragma once
/// Turns a runtime boolean into a compile-time constant.
///
/// COND is evaluated once. When it is true, a `constexpr static bool CONST_NAME`
/// equal to `true` is in scope while the callable passed as `...` is invoked;
/// otherwise the callable is invoked with CONST_NAME equal to `false`. Both
/// instantiations of the callable are compiled, and the switch expression
/// yields whatever the callable returns (so both must return the same type).
///
/// @param COND       runtime condition to branch on
/// @param CONST_NAME name of the constexpr bool visible inside the callable
/// @param ...        nullary callable (typically a `[&]` lambda) to invoke
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///   some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)     \
  [&] {                                        \
    if (COND) {                                \
      constexpr static bool CONST_NAME = true; \
      return __VA_ARGS__();                    \
    }                                          \
    constexpr static bool CONST_NAME = false;  \
    return __VA_ARGS__();                      \
  }()
// CAUSAL_LOCAL_SWITCH: turns two runtime flags into the compile-time constants
// CAUSAL_CONST_NAME and LOCAL_CONST_NAME (both `constexpr static bool`) and
// invokes the callable passed as `...`. The two constants are never both true:
//   CAUSAL_COND true                   -> causal = true,  local = false
//   CAUSAL_COND false, LOCAL_COND true -> causal = false, local = true
//   both false                         -> causal = false, local = false
// LOCAL_COND is evaluated only when CAUSAL_COND is false.
// With FLASHATTENTION_DISABLE_LOCAL defined, LOCAL_COND is ignored (never
// evaluated) and LOCAL_CONST_NAME is always false.
#ifndef FLASHATTENTION_DISABLE_LOCAL
#define CAUSAL_LOCAL_SWITCH(CAUSAL_COND, LOCAL_COND, CAUSAL_CONST_NAME, LOCAL_CONST_NAME, ...) \
  [&] {                                                                                        \
    if (CAUSAL_COND) {                                                                         \
      constexpr static bool CAUSAL_CONST_NAME = true;                                          \
      constexpr static bool LOCAL_CONST_NAME = false;                                          \
      return __VA_ARGS__();                                                                    \
    }                                                                                          \
    if (LOCAL_COND) {                                                                          \
      constexpr static bool CAUSAL_CONST_NAME = false;                                         \
      constexpr static bool LOCAL_CONST_NAME = true;                                           \
      return __VA_ARGS__();                                                                    \
    }                                                                                          \
    constexpr static bool CAUSAL_CONST_NAME = false;                                           \
    constexpr static bool LOCAL_CONST_NAME = false;                                            \
    return __VA_ARGS__();                                                                      \
  }()
#else
#define CAUSAL_LOCAL_SWITCH(CAUSAL_COND, LOCAL_COND, CAUSAL_CONST_NAME, LOCAL_CONST_NAME, ...) \
  [&] {                                                                                        \
    constexpr static bool LOCAL_CONST_NAME = false;                                            \
    if (CAUSAL_COND) {                                                                         \
      constexpr static bool CAUSAL_CONST_NAME = true;                                          \
      return __VA_ARGS__();                                                                    \
    }                                                                                          \
    constexpr static bool CAUSAL_CONST_NAME = false;                                           \
    return __VA_ARGS__();                                                                      \
  }()
#endif
// FLASHATTENTION_FALSE_SWITCH: stand-in for BOOL_SWITCH when a feature is
// compiled out. COND is ignored (never evaluated) and CONST_NAME is always
// `false`, so only the `false` instantiation of the callable is compiled.
// The per-feature switches below alias this one macro instead of each
// repeating the same lambda body in their #ifdef branch.
#define FLASHATTENTION_FALSE_SWITCH(COND, CONST_NAME, ...) \
  [&] {                                                    \
    constexpr static bool CONST_NAME = false;              \
    return __VA_ARGS__();                                  \
  }()

// Per-feature switches. Defining FLASHATTENTION_DISABLE_<FEATURE> pins the
// corresponding constant to `false`; otherwise the switch is BOOL_SWITCH.
// Each name expands to another function-like macro name, so call sites keep
// the usual `XXX_SWITCH(cond, Name, [&] { ... })` form.
#ifdef FLASHATTENTION_DISABLE_SOFTCAP
#define SOFTCAP_SWITCH FLASHATTENTION_FALSE_SWITCH
#else
#define SOFTCAP_SWITCH BOOL_SWITCH
#endif
#ifdef FLASHATTENTION_DISABLE_PAGEDKV
#define PAGEDKV_SWITCH FLASHATTENTION_FALSE_SWITCH
#else
#define PAGEDKV_SWITCH BOOL_SWITCH
#endif
#ifdef FLASHATTENTION_DISABLE_SPLIT
#define SPLIT_SWITCH FLASHATTENTION_FALSE_SWITCH
#else
#define SPLIT_SWITCH BOOL_SWITCH
#endif
#ifdef FLASHATTENTION_DISABLE_APPENDKV
#define APPENDKV_SWITCH FLASHATTENTION_FALSE_SWITCH
#else
#define APPENDKV_SWITCH BOOL_SWITCH
#endif
#ifdef FLASHATTENTION_DISABLE_PACKGQA
#define PACKGQA_SWITCH FLASHATTENTION_FALSE_SWITCH
#else
#define PACKGQA_SWITCH BOOL_SWITCH
#endif
#ifdef FLASHATTENTION_DISABLE_VARLEN
#define VARLEN_SWITCH FLASHATTENTION_FALSE_SWITCH
#else
#define VARLEN_SWITCH BOOL_SWITCH
#endif
#ifdef FLASHATTENTION_DISABLE_CLUSTER
#define CLUSTER_SWITCH FLASHATTENTION_FALSE_SWITCH
#else
#define CLUSTER_SWITCH BOOL_SWITCH
#endif
// ARCH_SWITCH: maps a runtime architecture number ARCH to the compile-time
// constant `constexpr static int ARCH_NAME` and invokes the callable in `...`:
//   86 or 89            -> 86
//   any other value < 90 -> 80
//   90 and above        -> 90
// ARCH is evaluated exactly once into a local and compared as a whole, so
// arguments containing lower-precedence operators (e.g. `x & 0xff`,
// `c ? a : b`) and arguments with side effects behave as written.
// With FLASHATTENTION_DISABLE_SM8x defined, ARCH is ignored (never evaluated)
// and ARCH_NAME is always 90.
#ifdef FLASHATTENTION_DISABLE_SM8x
#define ARCH_SWITCH(ARCH, ARCH_NAME, ...) \
  [&] {                                   \
    constexpr static int ARCH_NAME = 90;  \
    return __VA_ARGS__();                 \
  }()
#else
#define ARCH_SWITCH(ARCH, ARCH_NAME, ...)                     \
  [&] {                                                       \
    const auto flash_arch_value_ = (ARCH);                    \
    if (flash_arch_value_ == 86 || flash_arch_value_ == 89) { \
      constexpr static int ARCH_NAME = 86;                    \
      return __VA_ARGS__();                                   \
    } else if (flash_arch_value_ < 90) {                      \
      constexpr static int ARCH_NAME = 80;                    \
      return __VA_ARGS__();                                   \
    } else {                                                  \
      constexpr static int ARCH_NAME = 90;                    \
      return __VA_ARGS__();                                   \
    }                                                         \
  }()
#endif
// VCOLMAJOR_SWITCH: behaves like BOOL_SWITCH only when
// FLASHATTENTION_ENABLE_VCOLMAJOR is defined. By default COND is ignored
// (never evaluated) and CONST_NAME is always `false`, so only the `false`
// instantiation of the callable is compiled.
#ifdef FLASHATTENTION_ENABLE_VCOLMAJOR
#define VCOLMAJOR_SWITCH BOOL_SWITCH
#else
#define VCOLMAJOR_SWITCH(COND, CONST_NAME, ...) \
  [&] {                                         \
    constexpr static bool CONST_NAME = false;   \
    return __VA_ARGS__();                       \
  }()
#endif
// HEADDIM_SWITCH: maps a runtime head dimension to the compile-time constant
// `constexpr static int kHeadSize` (the name is fixed, not a parameter) and
// invokes the callable passed as `...`. Supported values: 64, 96, 128, 256.
//
// HEADDIM is evaluated exactly once into a local and compared as a whole, so
// arguments such as `big ? 96 : 64` select the intended size.
//
// NOTE: for any other value no branch runs and the callable is not invoked.
// That is harmless for callables returning void, but if the callable returns
// a value, flowing off the end of the generated lambda is undefined behavior.
// Validate HEADDIM before calling in that case.
// NOTE(review): an earlier revision had a second, unreachable `HEADDIM == 96`
// branch between 128 and 256; 192 may have been intended. Add it only if
// kernels for that size are actually instantiated.
#define HEADDIM_SWITCH(HEADDIM, ...)             \
  [&] {                                          \
    const auto flash_headdim_value_ = (HEADDIM); \
    if (flash_headdim_value_ == 64) {            \
      constexpr static int kHeadSize = 64;       \
      return __VA_ARGS__();                      \
    } else if (flash_headdim_value_ == 96) {     \
      constexpr static int kHeadSize = 96;       \
      return __VA_ARGS__();                      \
    } else if (flash_headdim_value_ == 128) {    \
      constexpr static int kHeadSize = 128;      \
      return __VA_ARGS__();                      \
    } else if (flash_headdim_value_ == 256) {    \
      constexpr static int kHeadSize = 256;      \
      return __VA_ARGS__();                      \
    }                                            \
  }()
|