File size: 10,395 Bytes
1ea89dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// This file provides utilities for dispatching to specialized versions of
// functions. This is especially useful for CUDA kernels, since specializing
// them to particular input sizes can often allow the compiler to unroll loops
// and place arrays into registers, which can give huge performance speedups.
//
// As an example, suppose we have the following function which is specialized
// based on a compile-time int64_t value:
//
// template<typename T, int64_t x>
// struct SquareOffset {
//   static void run(T y) {
//     T val = x * x + y;
//     std::cout << val << std::endl;
//   }
// };
//
// This function takes one compile-time argument x, and one run-time argument y.
// We might want to compile specialized versions of this for x=0, x=1, etc and
// then dispatch to the correct one based on the runtime value of x.
// One simple way to achieve this is with a lookup table:
//
// template<typename T>
// void DispatchSquareOffset(const int64_t x, T y) {
//   if (x == 0) {
//     SquareOffset<T, 0>::run(y);
//   } else if (x == 1) {
//     SquareOffset<T, 1>::run(y);
//   } else if (x == 2) {
//     SquareOffset<T, 2>::run(y);
//   }
// }
//
// This function takes both x and y as run-time arguments, and dispatches to
// different specialized versions of SquareOffset based on the run-time value
// of x. This works, but it's tedious and error-prone. If we want to change the
// set of x values for which we provide compile-time specializations, then we
// will need to do a lot of tedious editing of the dispatch function. Also, if we
// want to provide compile-time specializations for another function other than
// SquareOffset, we will need to duplicate the entire lookup table.
//
// To solve these problems, we can use the DispatchKernel1D function provided by
// this file instead:
//
// template<typename T>
// void DispatchSquareOffset(const int64_t x, T y) {
//     constexpr int64_t xmin = 0;
//     constexpr int64_t xmax = 2;
//     DispatchKernel1D<SquareOffset, T, xmin, xmax>(x, y);
// }
//
// DispatchKernel1D uses template metaprogramming to compile specialized
// versions of SquareOffset for all values of x with xmin <= x <= xmax, and
// then dispatches to the correct one based on the run-time value of x. If we
// want to change the range of x values for which SquareOffset is specialized
// at compile-time, then all we have to do is change the values of the
// compile-time constants xmin and xmax.
//
// This file also allows us to similarly dispatch functions that depend on two
// compile-time int64_t values, using the DispatchKernel2D function like this:
//
// template<typename T, int64_t x, int64_t y>
// struct Sum {
//   static void run(T z, T w) {
//     T val = x + y + z + w;
//     std::cout << val << std::endl;
//   }
// };
//
// template<typename T>
// void DispatchSum(const int64_t x, const int64_t y, T z, T w) {
//   constexpr int64_t xmin = 1;
//   constexpr int64_t xmax = 3;
//   constexpr int64_t ymin = 2;
//   constexpr int64_t ymax = 5;
//   DispatchKernel2D<Sum, T, xmin, xmax, ymin, ymax>(x, y, z, w);
// }
//
// Like its 1D counterpart, DispatchKernel2D uses template metaprogramming to
// compile specialized versions of Sum for all values of (x, y) with
// xmin <= x <= xmax and ymin <= y <= ymax, then dispatches to the correct
// specialized version based on the runtime values of x and y.

// Define some helper structs in an anonymous namespace.
namespace {

// 1D dispatch: primary template (the curN == maxN case has its own
// specialization below).
// Kernel is the class template we want to dispatch to; it takes a typename and
// an int64_t as template args, and must define a static function run that can
// be called with the run-time arguments Args... .
// curN is the compile-time candidate for the run-time value N. run() compares
// the two and, while curN is still smaller than N, recurses into the helper
// instantiated with curN + 1.
template <
    template <typename, int64_t>
    class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t curN,
    typename... Args>
struct DispatchKernelHelper1D {
  static void run(const int64_t N, Args... args) {
    // Helper instantiated for the next candidate value.
    using NextHelper =
        DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>;

    if (N == curN) {
      // The compile-time candidate matches the run-time value: call the
      // kernel specialized for curN.
      Kernel<T, curN>::run(args...);
      return;
    }
    if (N > curN) {
      // Not there yet; continue with the next compile-time value.
      NextHelper::run(N, args...);
      return;
    }
    // N < curN: nothing is dispatched and no error is reported.
  }
};

// 1D dispatch: specialization for curN == maxN.
// This is the base case of the template recursion: it never refers to a
// helper with a larger curN, so instantiation stops here.
template <
    template <typename, int64_t>
    class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    typename... Args>
struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
  static void run(const int64_t N, Args... args) {
    if (maxN != N) {
      // No specialization matches N; nothing is dispatched and no error is
      // reported.
      return;
    }
    Kernel<T, maxN>::run(args...);
  }
};

// 2D dispatch: primary template.
// Works like the 1D helper, but tracks two compile-time candidates, curN and
// curM, for the run-time values N and M. Each call either dispatches to
// Kernel<T, curN, curM> or recurses into a helper in which the candidate(s)
// still smaller than their run-time value have been incremented.
template <
    template <typename, int64_t, int64_t>
    class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t curN,
    int64_t minM,
    int64_t maxM,
    int64_t curM,
    typename... Args>
struct DispatchKernelHelper2D {
  static void run(const int64_t N, const int64_t M, Args... args) {
    // Helpers for the three possible next steps.
    using AdvanceBoth = DispatchKernelHelper2D<
        Kernel,
        T,
        minN,
        maxN,
        curN + 1,
        minM,
        maxM,
        curM + 1,
        Args...>;
    using AdvanceN = DispatchKernelHelper2D<
        Kernel,
        T,
        minN,
        maxN,
        curN + 1,
        minM,
        maxM,
        curM,
        Args...>;
    using AdvanceM = DispatchKernelHelper2D<
        Kernel,
        T,
        minN,
        maxN,
        curN,
        minM,
        maxM,
        curM + 1,
        Args...>;

    const bool n_behind = curN < N;
    const bool m_behind = curM < M;

    if (curN == N && curM == M) {
      // Both candidates match: call the kernel specialized for (curN, curM).
      Kernel<T, curN, curM>::run(args...);
      return;
    }

    if (n_behind && m_behind) {
      // Advance both candidates at once. Stepping one at a time would also
      // work, but this cuts down on the number of recursive calls we make.
      AdvanceBoth::run(N, M, args...);
    } else if (n_behind) {
      // Only curN still needs to grow.
      AdvanceN::run(N, M, args...);
    } else if (m_behind) {
      // Only curM still needs to grow.
      AdvanceM::run(N, M, args...);
    }
    // Otherwise a candidate is already past its run-time value; nothing is
    // dispatched and no error is reported.
  }
};

// 2D dispatch, specialization for curN == maxN
template <
    template <typename, int64_t, int64_t>
    class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t minM,
    int64_t maxM,
    int64_t curM,
    typename... Args>
struct DispatchKernelHelper2D<
    Kernel,
    T,
    minN,
    maxN,
    maxN,
    minM,
    maxM,
    curM,
    Args...> {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (maxN == N && curM == M) {
      Kernel<T, maxN, curM>::run(args...);
    } else if (curM < maxM) {
      DispatchKernelHelper2D<
          Kernel,
          T,
          minN,
          maxN,
          maxN,
          minM,
          maxM,
          curM + 1,
          Args...>::run(N, M, args...);
    }
    // We should not get here -- throw an error?
  }
};

// 2D dispatch, specialization for curM == maxM
template <
    template <typename, int64_t, int64_t>
    class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t curN,
    int64_t minM,
    int64_t maxM,
    typename... Args>
struct DispatchKernelHelper2D<
    Kernel,
    T,
    minN,
    maxN,
    curN,
    minM,
    maxM,
    maxM,
    Args...> {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (curN == N && maxM == M) {
      Kernel<T, curN, maxM>::run(args...);
    } else if (curN < maxN) {
      DispatchKernelHelper2D<
          Kernel,
          T,
          minN,
          maxN,
          curN + 1,
          minM,
          maxM,
          maxM,
          Args...>::run(N, M, args...);
    }
    // We should not get here -- throw an error?
  }
};

// 2D dispatch: specialization for curN == maxN and curM == maxM.
// Both candidates are at their maximum. This is the base case of the 2D
// template recursion: it refers to no further helper.
template <
    template <typename, int64_t, int64_t>
    class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t minM,
    int64_t maxM,
    typename... Args>
struct DispatchKernelHelper2D<
    Kernel,
    T,
    minN,
    maxN,
    maxN,
    minM,
    maxM,
    maxM,
    Args...> {
  static void run(const int64_t N, const int64_t M, Args... args) {
    const bool matches = (N == maxN) && (M == maxM);
    if (!matches) {
      // No specialization matches (N, M); nothing is dispatched and no error
      // is reported.
      return;
    }
    Kernel<T, maxN, maxM>::run(args...);
  }
};

} // namespace

// This is the function we expect users to call to dispatch to 1D functions.
//
// Template args:
//   Kernel: class template taking a typename and an int64_t; it must define a
//     static function run that can be called with args... .
//   T: forwarded as the first template argument of Kernel.
//   minN, maxN: inclusive range of values for which Kernel is specialized.
//     minN must not exceed maxN.
// Run-time args:
//   N: selects the specialization; Kernel<T, N>::run(args...) is called when
//     minN <= N <= maxN.
//   args: passed (by value) to Kernel<T, N>::run.
// If N is outside [minN, maxN] nothing is called and no error is reported.
template <
    template <typename, int64_t>
    class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    typename... Args>
void DispatchKernel1D(const int64_t N, Args... args) {
  // With minN > maxN the helper recursion below would never reach its
  // curN == maxN base case; report that misuse with a readable message.
  static_assert(
      minN <= maxN, "DispatchKernel1D requires minN <= maxN");
  if (minN <= N && N <= maxN) {
    // Kick off the template recursion by calling the Helper with curN = minN
    DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(
        N, args...);
  }
  // Maybe throw an error if we tried to dispatch outside the allowed range?
}

// This is the function we expect users to call to dispatch to 2D functions.
//
// Template args:
//   Kernel: class template taking a typename and two int64_t values; it must
//     define a static function run that can be called with args... .
//   T: forwarded as the first template argument of Kernel.
//   minN, maxN: inclusive range of the first specialized value (minN <= maxN).
//   minM, maxM: inclusive range of the second specialized value
//     (minM <= maxM).
// Run-time args:
//   N, M: select the specialization; Kernel<T, N, M>::run(args...) is called
//     when minN <= N <= maxN and minM <= M <= maxM.
//   args: passed (by value) to Kernel<T, N, M>::run.
// If (N, M) is outside the specified ranges nothing is called and no error is
// reported.
template <
    template <typename, int64_t, int64_t>
    class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t minM,
    int64_t maxM,
    typename... Args>
void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) {
  // With a reversed range the helper recursion below would never reach its
  // base case; report that misuse with a readable message.
  static_assert(
      minN <= maxN, "DispatchKernel2D requires minN <= maxN");
  static_assert(
      minM <= maxM, "DispatchKernel2D requires minM <= maxM");
  if (minN <= N && N <= maxN && minM <= M && M <= maxM) {
    // Kick off the template recursion by calling the Helper with curN = minN
    // and curM = minM
    DispatchKernelHelper2D<
        Kernel,
        T,
        minN,
        maxN,
        minN,
        minM,
        maxM,
        minM,
        Args...>::run(N, M, args...);
  }
  // Maybe throw an error if we tried to dispatch outside the specified range?
}