Spaces:
Running
on
Zero
Running
on
Zero
File size: 10,395 Bytes
1ea89dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 |
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
// This file provides utilities for dispatching to specialized versions of
// functions. This is especially useful for CUDA kernels, since specializing
// them to particular input sizes can often allow the compiler to unroll loops
// and place arrays into registers, which can give huge performance speedups.
//
// As an example, suppose we have the following function which is specialized
// based on a compile-time int64_t value:
//
// template<typename T, int64_t x>
// struct SquareOffset {
// static void run(T y) {
// T val = x * x + y;
// std::cout << val << std::endl;
// }
// };
//
// This function takes one compile-time argument x, and one run-time argument y.
// We might want to compile specialized versions of this for x=0, x=1, etc and
// then dispatch to the correct one based on the runtime value of x.
// One simple way to achieve this is with a lookup table:
//
// template<typename T>
// void DispatchSquareOffset(const int64_t x, T y) {
// if (x == 0) {
// SquareOffset<T, 0>::run(y);
// } else if (x == 1) {
// SquareOffset<T, 1>::run(y);
// } else if (x == 2) {
// SquareOffset<T, 2>::run(y);
// }
// }
//
// This function takes both x and y as run-time arguments, and dispatches to
// different specialized versions of SquareOffset based on the run-time value
// of x. This works, but it's tedious and error-prone. If we want to change the
// set of x values for which we provide compile-time specializations, then we
// will need to do a lot of tedious editing of the dispatch function. Also, if we
// want to provide compile-time specializations for another function other than
// SquareOffset, we will need to duplicate the entire lookup table.
//
// To solve these problems, we can use the DispatchKernel1D function provided by
// this file instead:
//
// template<typename T>
// void DispatchSquareOffset(const int64_t x, T y) {
// constexpr int64_t xmin = 0;
// constexpr int64_t xmax = 2;
// DispatchKernel1D<SquareOffset, T, xmin, xmax>(x, y);
// }
//
// DispatchKernel1D uses template metaprogramming to compile specialized
// versions of SquareOffset for all values of x with xmin <= x <= xmax, and
// then dispatches to the correct one based on the run-time value of x. If we
// want to change the range of x values for which SquareOffset is specialized
// at compile-time, then all we have to do is change the values of the
// compile-time constants xmin and xmax.
//
// This file also allows us to similarly dispatch functions that depend on two
// compile-time int64_t values, using the DispatchKernel2D function like this:
//
// template<typename T, int64_t x, int64_t y>
// struct Sum {
// static void run(T z, T w) {
// T val = x + y + z + w;
// std::cout << val << std::endl;
// }
// };
//
// template<typename T>
// void DispatchSum(const int64_t x, const int64_t y, T z, T w) {
// constexpr int64_t xmin = 1;
// constexpr int64_t xmax = 3;
// constexpr int64_t ymin = 2;
// constexpr int64_t ymax = 5;
// DispatchKernel2D<Sum, T, xmin, xmax, ymin, ymax>(x, y, z, w);
// }
//
// Like its 1D counterpart, DispatchKernel2D uses template metaprogramming to
// compile specialized versions of Sum for all values of (x, y) with
// xmin <= x <= xmax and ymin <= y <= ymax, then dispatches to the correct
// specialized version based on the runtime values of x and y.
// Define some helper structs in an anonymous namespace.
namespace {
// 1D dispatch helper: primary template.
//
// Kernel is the class template being dispatched to. It must accept a typename
// and an int64_t as template arguments and expose a static void method `run`,
// which may take any number of arguments of any type.
//
// The extra template argument curN is a compile-time candidate for the
// run-time value N. Each level of the template recursion compares curN with N
// and, while curN is still smaller than N, hands over to the instantiation
// for curN + 1.
template <
    template <typename, int64_t>
    class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t curN,
    typename... Args>
struct DispatchKernelHelper1D {
  static void run(const int64_t N, Args... args) {
    if (N == curN) {
      // The run-time N matches this instantiation's compile-time curN, so
      // invoke the kernel specialized for that value.
      Kernel<T, curN>::run(args...);
      return;
    }
    if (N > curN) {
      // Not there yet: move on to the next compile-time candidate.
      using Next =
          DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>;
      Next::run(N, args...);
      return;
    }
    // N < curN: no kernel is invoked and the call silently does nothing.
    // We shouldn't get here -- throw an error?
  }
};

// 1D dispatch helper: partial specialization for curN == maxN.
//
// This is the base case that terminates the template recursion: it never
// refers to an instantiation for maxN + 1.
template <
    template <typename, int64_t>
    class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    typename... Args>
struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
  static void run(const int64_t N, Args... args) {
    if (N != maxN) {
      // No kernel is invoked for any other N.
      // We shouldn't get here -- throw an error?
      return;
    }
    Kernel<T, maxN>::run(args...);
  }
};
// 2D dispatch helper: primary template.
//
// Works like the 1D helper, but tracks two compile-time candidates, curN and
// curM, for the run-time values N and M. The candidates are advanced by
// template recursion until both match, at which point the run method of the
// kernel specialized for that pair is invoked.
template <
    template <typename, int64_t, int64_t>
    class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t curN,
    int64_t minM,
    int64_t maxM,
    int64_t curM,
    typename... Args>
struct DispatchKernelHelper2D {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (N == curN && M == curM) {
      // Both candidates match the run-time values: dispatch.
      Kernel<T, curN, curM>::run(args...);
      return;
    }

    const bool advanceN = curN < N;
    const bool advanceM = curM < M;

    if (advanceN && advanceM) {
      // Step both candidates at once. Stepping only one of them per level
      // would also work, but stepping both cuts down on the number of
      // recursive calls.
      using StepBoth = DispatchKernelHelper2D<
          Kernel,
          T,
          minN,
          maxN,
          curN + 1,
          minM,
          maxM,
          curM + 1,
          Args...>;
      StepBoth::run(N, M, args...);
    } else if (advanceN) {
      // Only curN still trails its run-time value.
      using StepN = DispatchKernelHelper2D<
          Kernel,
          T,
          minN,
          maxN,
          curN + 1,
          minM,
          maxM,
          curM,
          Args...>;
      StepN::run(N, M, args...);
    } else if (advanceM) {
      // Only curM still trails its run-time value.
      using StepM = DispatchKernelHelper2D<
          Kernel,
          T,
          minN,
          maxN,
          curN,
          minM,
          maxM,
          curM + 1,
          Args...>;
      StepM::run(N, M, args...);
    }
    // Otherwise no kernel is invoked.
  }
};

// 2D dispatch helper: partial specialization for curN == maxN.
//
// curN cannot be advanced any further, so only curM is stepped. The step is
// guarded by the compile-time comparison curM < maxM rather than by M.
template <
    template <typename, int64_t, int64_t>
    class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t minM,
    int64_t maxM,
    int64_t curM,
    typename... Args>
struct DispatchKernelHelper2D<
    Kernel,
    T,
    minN,
    maxN,
    maxN,
    minM,
    maxM,
    curM,
    Args...> {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (N == maxN && M == curM) {
      Kernel<T, maxN, curM>::run(args...);
      return;
    }
    if (curM < maxM) {
      using StepM = DispatchKernelHelper2D<
          Kernel,
          T,
          minN,
          maxN,
          maxN,
          minM,
          maxM,
          curM + 1,
          Args...>;
      StepM::run(N, M, args...);
      return;
    }
    // We should not get here -- throw an error?
  }
};

// 2D dispatch helper: partial specialization for curM == maxM.
//
// Mirror image of the curN == maxN case: curM cannot be advanced any
// further, so only curN is stepped, guarded by the compile-time comparison
// curN < maxN.
template <
    template <typename, int64_t, int64_t>
    class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t curN,
    int64_t minM,
    int64_t maxM,
    typename... Args>
struct DispatchKernelHelper2D<
    Kernel,
    T,
    minN,
    maxN,
    curN,
    minM,
    maxM,
    maxM,
    Args...> {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (N == curN && M == maxM) {
      Kernel<T, curN, maxM>::run(args...);
      return;
    }
    if (curN < maxN) {
      using StepN = DispatchKernelHelper2D<
          Kernel,
          T,
          minN,
          maxN,
          curN + 1,
          minM,
          maxM,
          maxM,
          Args...>;
      StepN::run(N, M, args...);
      return;
    }
    // We should not get here -- throw an error?
  }
};

// 2D dispatch helper: partial specialization for curN == maxN and
// curM == maxM.
//
// Final base case: neither candidate can be advanced, so this level either
// dispatches or does nothing.
template <
    template <typename, int64_t, int64_t>
    class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t minM,
    int64_t maxM,
    typename... Args>
struct DispatchKernelHelper2D<
    Kernel,
    T,
    minN,
    maxN,
    maxN,
    minM,
    maxM,
    maxM,
    Args...> {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (N != maxN || M != maxM) {
      // We should not get here -- throw an error?
      return;
    }
    Kernel<T, maxN, maxM>::run(args...);
  }
};
} // namespace
// Public entry point for dispatching to 1D-specialized kernels.
//
// Template arguments:
//   Kernel     - class template taking a typename and an int64_t; its static
//                void method `run` is what ends up being called.
//   T          - type passed through as Kernel's first template argument.
//   minN, maxN - inclusive range of values for which Kernel is specialized.
//                Must satisfy minN <= maxN.
//   Args       - types of the run-time arguments handed to Kernel::run.
//
// Run-time arguments:
//   N    - selects which specialization Kernel<T, N> is invoked.
//   args - passed by value on to Kernel<T, N>::run.
//
// If N lies outside [minN, maxN] no kernel is invoked and the call silently
// does nothing.
template <
    template <typename, int64_t>
    class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    typename... Args>
void DispatchKernel1D(const int64_t N, Args... args) {
  // The helper's recursion counts upwards from minN and only terminates at
  // maxN, so a reversed range can never reach the base case. Diagnose it
  // with a readable message.
  static_assert(
      minN <= maxN, "DispatchKernel1D requires minN <= maxN");

  if (N < minN || N > maxN) {
    // Maybe throw an error if we tried to dispatch outside the allowed range?
    return;
  }
  // Kick off the template recursion by calling the helper with curN = minN.
  DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(
      N, args...);
}
// Public entry point for dispatching to 2D-specialized kernels.
//
// Template arguments:
//   Kernel     - class template taking a typename and two int64_t values; its
//                static void method `run` is what ends up being called.
//   T          - type passed through as Kernel's first template argument.
//   minN, maxN - inclusive range of first-index values for which Kernel is
//                specialized. Must satisfy minN <= maxN.
//   minM, maxM - inclusive range of second-index values for which Kernel is
//                specialized. Must satisfy minM <= maxM.
//   Args       - types of the run-time arguments handed to Kernel::run.
//
// Run-time arguments:
//   N, M - select which specialization Kernel<T, N, M> is invoked.
//   args - passed by value on to Kernel<T, N, M>::run.
//
// If N lies outside [minN, maxN] or M lies outside [minM, maxM], no kernel
// is invoked and the call silently does nothing.
template <
    template <typename, int64_t, int64_t>
    class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t minM,
    int64_t maxM,
    typename... Args>
void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) {
  // The helper's recursion counts upwards from (minN, minM) and only
  // terminates at (maxN, maxM), so a reversed range can never reach the base
  // case. Diagnose it with a readable message.
  static_assert(
      minN <= maxN, "DispatchKernel2D requires minN <= maxN");
  static_assert(
      minM <= maxM, "DispatchKernel2D requires minM <= maxM");

  const bool nInRange = minN <= N && N <= maxN;
  const bool mInRange = minM <= M && M <= maxM;
  if (!nInRange || !mInRange) {
    // Maybe throw an error if we tried to dispatch outside the specified
    // range?
    return;
  }
  // Kick off the template recursion by calling the helper with curN = minN
  // and curM = minM.
  DispatchKernelHelper2D<
      Kernel,
      T,
      minN,
      maxN,
      minN,
      minM,
      maxM,
      minM,
      Args...>::run(N, M, args...);
}