Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,720 Bytes
1ea89dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#define MINK_H
#include "index_utils.cuh"
// A data structure to keep track of the smallest K keys seen so far as well
// as their associated values, intended to be used in device code.
// This data structure doesn't allocate any memory; keys and values are stored
// in arrays passed to the constructor.
//
// The implementation is generic; it can be used for any key type that supports
// the < operator, and can be used with any value type.
//
// Example usage:
//
// float keys[K];
// int values[K];
// MinK<float, int> mink(keys, values, K);
// for (...) {
// // Produce some key and value from somewhere
// mink.add(key, value);
// }
// mink.sort();
//
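// Now keys and values hold the smallest keys seen so far (at most K of them),
// sorted in ascending order, together with the values associated with those
// keys. If fewer than K items were added, only the first mink.size() entries
// are valid:
//
//   for (int k = 0; k < mink.size(); ++k) {
//     float key_k = keys[k];
//     int value_k = values[k];
//   }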
template <typename key_t, typename value_t>
class MinK {
 public:
  // Constructor. No memory is allocated; the caller owns both arrays.
  //
  // Arguments:
  //   keys: Array in which to store keys; must hold at least K elements.
  //   vals: Array in which to store values; must hold at least K elements.
  //   K: How many values to keep track of. If K <= 0, add() is a no-op.
  __device__ MinK(key_t* keys, value_t* vals, int K)
      : keys(keys), vals(vals), K(K), _size(0), max_key(), max_idx(0) {}

  // Try to add a new key and associated value to the data structure. If the
  // key is one of the smallest K seen so far then it will be kept; otherwise
  // it will not be kept.
  //
  // This takes O(1) operations if the new key is not kept, or if the structure
  // currently contains fewer than K elements. Otherwise this takes O(K) time.
  //
  // Only operator< is used to compare keys.
  //
  // Arguments:
  //   key: The key to add
  //   val: The value associated to the key
  __device__ __forceinline__ void add(const key_t& key, const value_t& val) {
    if (_size < K) {
      // Still filling up: append and keep track of the running maximum.
      keys[_size] = key;
      vals[_size] = val;
      if (_size == 0 || max_key < key) {
        max_key = key;
        max_idx = _size;
      }
      _size++;
    } else if (K > 0 && key < max_key) {
      // Full: overwrite the current maximum, then rescan for the new maximum.
      // The K > 0 guard matters because max_key/max_idx are only meaningful
      // once at least one element has been stored; without it, K <= 0 would
      // write through an unset index.
      keys[max_idx] = key;
      vals[max_idx] = val;
      max_key = key;
      for (int k = 0; k < K; ++k) {
        const key_t cur_key = keys[k];
        if (max_key < cur_key) {
          max_key = cur_key;
          max_idx = k;
        }
      }
    }
  }

  // Get the number of items currently stored in the structure.
  // This takes O(1) time.
  __device__ __forceinline__ int size() const {
    return _size;
  }

  // Sort the stored items in ascending key order using (stable) bubble sort.
  // This takes O(K^2) time. add() may still be called afterwards.
  __device__ __forceinline__ void sort() {
    for (int i = 0; i < _size - 1; ++i) {
      for (int j = 0; j < _size - i - 1; ++j) {
        if (keys[j + 1] < keys[j]) {
          key_t key = keys[j];
          value_t val = vals[j];
          keys[j] = keys[j + 1];
          vals[j] = vals[j + 1];
          keys[j + 1] = key;
          vals[j + 1] = val;
        }
      }
    }
    // Sorting moved elements around, so the cached index of the maximum is
    // stale. After an ascending sort the maximum is the last element.
    if (_size > 0) {
      max_idx = _size - 1;
    }
  }

 private:
  key_t* keys;
  value_t* vals;
  int K;
  int _size;
  // Largest key currently stored and its slot; valid only while _size > 0.
  key_t max_key;
  int max_idx;
};
// This is a version of MinK that only touches the arrays using static indexing
// via RegisterIndexUtils. If the keys and values are stored in thread-local
// arrays, then this may allow the compiler to place them in registers for
// fast access.
//
// This has the same API as MinK (except that K is a template parameter rather
// than a constructor argument), but doesn't support sorting.
// We found that sorting via RegisterIndexUtils gave very poor performance,
// and suspect it may have prevented the compiler from placing the arrays
// into registers.
template <typename key_t, typename value_t, int K>
class RegisterMinK {
 public:
  // Constructor. No memory is allocated; the caller owns both arrays.
  //
  // Arguments:
  //   keys: Array in which to store keys; must hold at least K elements.
  //   vals: Array in which to store values; must hold at least K elements.
  __device__ RegisterMinK(key_t* keys, value_t* vals)
      : keys(keys), vals(vals), _size(0), max_key(), max_idx(0) {}

  // Try to add a new key and associated value. The key is kept only if it is
  // one of the smallest K seen so far. All array accesses go through
  // RegisterIndexUtils. Only operator< is used to compare keys.
  //
  // This takes O(1) operations if the new key is not kept, or if the structure
  // currently contains fewer than K elements. Otherwise this takes O(K) time.
  //
  // Arguments:
  //   key: The key to add
  //   val: The value associated to the key
  __device__ __forceinline__ void add(const key_t& key, const value_t& val) {
    if (_size < K) {
      // Still filling up: append and keep track of the running maximum.
      RegisterIndexUtils<key_t, K>::set(keys, _size, key);
      RegisterIndexUtils<value_t, K>::set(vals, _size, val);
      if (_size == 0 || max_key < key) {
        max_key = key;
        max_idx = _size;
      }
      _size++;
    } else if (K > 0 && key < max_key) {
      // Full: overwrite the current maximum, then rescan for the new maximum.
      // K is a compile-time constant, so the K > 0 guard folds away; it
      // prevents reading the unset max_key/max_idx when K <= 0.
      RegisterIndexUtils<key_t, K>::set(keys, max_idx, key);
      RegisterIndexUtils<value_t, K>::set(vals, max_idx, val);
      max_key = key;
      for (int k = 0; k < K; ++k) {
        const key_t cur_key = RegisterIndexUtils<key_t, K>::get(keys, k);
        if (max_key < cur_key) {
          max_key = cur_key;
          max_idx = k;
        }
      }
    }
  }

  // Get the number of items currently stored in the structure.
  // This takes O(1) time.
  __device__ __forceinline__ int size() const {
    return _size;
  }

 private:
  key_t* keys;
  value_t* vals;
  int _size;
  // Largest key currently stored and its slot; valid only while _size > 0.
  key_t max_key;
  int max_idx;
};