File size: 4,720 Bytes
1ea89dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once
#define MINK_H

#include "index_utils.cuh"

// A data structure to keep track of the smallest K keys seen so far as well
// as their associated values, intended to be used in device code.
// This data structure doesn't allocate any memory; keys and values are stored
// in arrays passed to the constructor.
//
// The implementation is generic; it can be used for any key type that supports
// the < operator, and can be used with any value type.
//
// Example usage:
//
// float keys[K];
// int values[K];
// MinK<float, int> mink(keys, values, K);
// for (...) {
//   // Produce some key and value from somewhere
//   mink.add(key, value);
// }
// mink.sort();
//
// Now keys and values store the smallest K keys seen so far and the values
// associated to these keys:
//
// for (int k = 0; k < K; ++k) {
//   float key_k = keys[k];
//   int value_k = values[k];
// }
template <typename key_t, typename value_t>
class MinK {
 public:
  // Constructor.
  //
  // Arguments:
  //   keys: Array in which to store keys
  //   values: Array in which to store values
  //   K: How many values to keep track of
  __device__ MinK(key_t* keys, value_t* vals, int K)
      : keys(keys), vals(vals), K(K), _size(0) {}

  // Try to add a new key and associated value to the data structure. If the key
  // is one of the smallest K seen so far then it will be kept; otherwise it
  // it will not be kept.
  //
  // This takes O(1) operations if the new key is not kept, or if the structure
  // currently contains fewer than K elements. Otherwise this takes O(K) time.
  //
  // Arguments:
  //   key: The key to add
  //   val: The value associated to the key
  __device__ __forceinline__ void add(const key_t& key, const value_t& val) {
    if (_size < K) {
      keys[_size] = key;
      vals[_size] = val;
      if (_size == 0 || key > max_key) {
        max_key = key;
        max_idx = _size;
      }
      _size++;
    } else if (key < max_key) {
      keys[max_idx] = key;
      vals[max_idx] = val;
      max_key = key;
      for (int k = 0; k < K; ++k) {
        key_t cur_key = keys[k];
        if (cur_key > max_key) {
          max_key = cur_key;
          max_idx = k;
        }
      }
    }
  }

  // Get the number of items currently stored in the structure.
  // This takes O(1) time.
  __device__ __forceinline__ int size() {
    return _size;
  }

  // Sort the items stored in the structure using bubble sort.
  // This takes O(K^2) time.
  __device__ __forceinline__ void sort() {
    for (int i = 0; i < _size - 1; ++i) {
      for (int j = 0; j < _size - i - 1; ++j) {
        if (keys[j + 1] < keys[j]) {
          key_t key = keys[j];
          value_t val = vals[j];
          keys[j] = keys[j + 1];
          vals[j] = vals[j + 1];
          keys[j + 1] = key;
          vals[j + 1] = val;
        }
      }
    }
  }

 private:
  key_t* keys;
  value_t* vals;
  int K;
  int _size;
  key_t max_key;
  int max_idx;
};

// This is a version of MinK that only touches the arrays using static indexing
// via RegisterIndexUtils. If the keys and values are stored in thread-local
// arrays, then this may allow the compiler to place them in registers for
// fast access.
//
// This has the same API as RegisterMinK, but doesn't support sorting.
// We found that sorting via RegisterIndexUtils gave very poor performance,
// and suspect it may have prevented the compiler from placing the arrays
// into registers.
template <typename key_t, typename value_t, int K>
class RegisterMinK {
 public:
  __device__ RegisterMinK(key_t* keys, value_t* vals)
      : keys(keys), vals(vals), _size(0) {}

  __device__ __forceinline__ void add(const key_t& key, const value_t& val) {
    if (_size < K) {
      RegisterIndexUtils<key_t, K>::set(keys, _size, key);
      RegisterIndexUtils<value_t, K>::set(vals, _size, val);
      if (_size == 0 || key > max_key) {
        max_key = key;
        max_idx = _size;
      }
      _size++;
    } else if (key < max_key) {
      RegisterIndexUtils<key_t, K>::set(keys, max_idx, key);
      RegisterIndexUtils<value_t, K>::set(vals, max_idx, val);
      max_key = key;
      for (int k = 0; k < K; ++k) {
        key_t cur_key = RegisterIndexUtils<key_t, K>::get(keys, k);
        if (cur_key > max_key) {
          max_key = cur_key;
          max_idx = k;
        }
      }
    }
  }

  __device__ __forceinline__ int size() {
    return _size;
  }

 private:
  key_t* keys;
  value_t* vals;
  int _size;
  key_t max_key;
  int max_idx;
};