Spaces:
Runtime error
Runtime error
| /****************************************************************************** | |
| * Copyright (c) 2011, Duane Merrill. All rights reserved. | |
| * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. | |
| * | |
| * Redistribution and use in source and binary forms, with or without | |
| * modification, are permitted provided that the following conditions are met: | |
| * * Redistributions of source code must retain the above copyright | |
| * notice, this list of conditions and the following disclaimer. | |
| * * Redistributions in binary form must reproduce the above copyright | |
| * notice, this list of conditions and the following disclaimer in the | |
| * documentation and/or other materials provided with the distribution. | |
| * * Neither the name of the NVIDIA CORPORATION nor the | |
| * names of its contributors may be used to endorse or promote products | |
| * derived from this software without specific prior written permission. | |
| * | |
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND | |
| * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED | |
| * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
| * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY | |
| * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | |
| * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | |
| * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND | |
| * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |
| * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | |
| * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
| * | |
| ******************************************************************************/ | |
| /** | |
| * \file | |
| * Simple binary operator functor types | |
| */ | |
| /****************************************************************************** | |
| * Simple functor operators | |
| ******************************************************************************/ | |
| #pragma once | |
| #include "../config.cuh" | |
| #include "../util_type.cuh" | |
| /// Optional outer namespace(s) | |
| CUB_NS_PREFIX | |
| /// CUB namespace | |
| namespace cub { | |
| /** | |
| * \addtogroup UtilModule | |
| * @{ | |
| */ | |
| /** | |
| * \brief Default equality functor | |
| */ | |
| struct Equality | |
| { | |
| /// Boolean equality operator, returns <tt>(a == b)</tt> | |
| template <typename T> | |
| __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b) const | |
| { | |
| return a == b; | |
| } | |
| }; | |
| /** | |
| * \brief Default inequality functor | |
| */ | |
| struct Inequality | |
| { | |
| /// Boolean inequality operator, returns <tt>(a != b)</tt> | |
| template <typename T> | |
| __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b) const | |
| { | |
| return a != b; | |
| } | |
| }; | |
| /** | |
| * \brief Inequality functor (wraps equality functor) | |
| */ | |
| template <typename EqualityOp> | |
| struct InequalityWrapper | |
| { | |
| /// Wrapped equality operator | |
| EqualityOp op; | |
| /// Constructor | |
| __host__ __device__ __forceinline__ | |
| InequalityWrapper(EqualityOp op) : op(op) {} | |
| /// Boolean inequality operator, returns <tt>(a != b)</tt> | |
| template <typename T> | |
| __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b) | |
| { | |
| return !op(a, b); | |
| } | |
| }; | |
| /** | |
| * \brief Default sum functor | |
| */ | |
| struct Sum | |
| { | |
| /// Boolean sum operator, returns <tt>a + b</tt> | |
| template <typename T> | |
| __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const | |
| { | |
| return a + b; | |
| } | |
| }; | |
| /** | |
| * \brief Default max functor | |
| */ | |
| struct Max | |
| { | |
| /// Boolean max operator, returns <tt>(a > b) ? a : b</tt> | |
| template <typename T> | |
| __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const | |
| { | |
| return CUB_MAX(a, b); | |
| } | |
| }; | |
| /** | |
| * \brief Arg max functor (keeps the value and offset of the first occurrence of the larger item) | |
| */ | |
| struct ArgMax | |
| { | |
| /// Boolean max operator, preferring the item having the smaller offset in case of ties | |
| template <typename T, typename OffsetT> | |
| __host__ __device__ __forceinline__ KeyValuePair<OffsetT, T> operator()( | |
| const KeyValuePair<OffsetT, T> &a, | |
| const KeyValuePair<OffsetT, T> &b) const | |
| { | |
| // Mooch BUG (device reduce argmax gk110 3.2 million random fp32) | |
| // return ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) ? b : a; | |
| if ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) | |
| return b; | |
| return a; | |
| } | |
| }; | |
| /** | |
| * \brief Default min functor | |
| */ | |
| struct Min | |
| { | |
| /// Boolean min operator, returns <tt>(a < b) ? a : b</tt> | |
| template <typename T> | |
| __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const | |
| { | |
| return CUB_MIN(a, b); | |
| } | |
| }; | |
| /** | |
| * \brief Arg min functor (keeps the value and offset of the first occurrence of the smallest item) | |
| */ | |
| struct ArgMin | |
| { | |
| /// Boolean min operator, preferring the item having the smaller offset in case of ties | |
| template <typename T, typename OffsetT> | |
| __host__ __device__ __forceinline__ KeyValuePair<OffsetT, T> operator()( | |
| const KeyValuePair<OffsetT, T> &a, | |
| const KeyValuePair<OffsetT, T> &b) const | |
| { | |
| // Mooch BUG (device reduce argmax gk110 3.2 million random fp32) | |
| // return ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) ? b : a; | |
| if ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) | |
| return b; | |
| return a; | |
| } | |
| }; | |
| /** | |
| * \brief Default cast functor | |
| */ | |
| template <typename B> | |
| struct CastOp | |
| { | |
| /// Cast operator, returns <tt>(B) a</tt> | |
| template <typename A> | |
| __host__ __device__ __forceinline__ B operator()(const A &a) const | |
| { | |
| return (B) a; | |
| } | |
| }; | |
| /** | |
| * \brief Binary operator wrapper for switching non-commutative scan arguments | |
| */ | |
| template <typename ScanOp> | |
| class SwizzleScanOp | |
| { | |
| private: | |
| /// Wrapped scan operator | |
| ScanOp scan_op; | |
| public: | |
| /// Constructor | |
| __host__ __device__ __forceinline__ | |
| SwizzleScanOp(ScanOp scan_op) : scan_op(scan_op) {} | |
| /// Switch the scan arguments | |
| template <typename T> | |
| __host__ __device__ __forceinline__ | |
| T operator()(const T &a, const T &b) | |
| { | |
| T _a(a); | |
| T _b(b); | |
| return scan_op(_b, _a); | |
| } | |
| }; | |
| /** | |
| * \brief Reduce-by-segment functor. | |
| * | |
| * Given two cub::KeyValuePair inputs \p a and \p b and a | |
| * binary associative combining operator \p <tt>f(const T &x, const T &y)</tt>, | |
| * an instance of this functor returns a cub::KeyValuePair whose \p key | |
| * field is <tt>a.key</tt> + <tt>b.key</tt>, and whose \p value field | |
| * is either b.value if b.key is non-zero, or f(a.value, b.value) otherwise. | |
| * | |
| * ReduceBySegmentOp is an associative, non-commutative binary combining operator | |
| * for input sequences of cub::KeyValuePair pairings. Such | |
| * sequences are typically used to represent a segmented set of values to be reduced | |
| * and a corresponding set of {0,1}-valued integer "head flags" demarcating the | |
| * first value of each segment. | |
| * | |
| */ | |
| template <typename ReductionOpT> ///< Binary reduction operator to apply to values | |
| struct ReduceBySegmentOp | |
| { | |
| /// Wrapped reduction operator | |
| ReductionOpT op; | |
| /// Constructor | |
| __host__ __device__ __forceinline__ ReduceBySegmentOp() {} | |
| /// Constructor | |
| __host__ __device__ __forceinline__ ReduceBySegmentOp(ReductionOpT op) : op(op) {} | |
| /// Scan operator | |
| template <typename KeyValuePairT> ///< KeyValuePair pairing of T (value) and OffsetT (head flag) | |
| __host__ __device__ __forceinline__ KeyValuePairT operator()( | |
| const KeyValuePairT &first, ///< First partial reduction | |
| const KeyValuePairT &second) ///< Second partial reduction | |
| { | |
| KeyValuePairT retval; | |
| retval.key = first.key + second.key; | |
| retval.value = (second.key) ? | |
| second.value : // The second partial reduction spans a segment reset, so it's value aggregate becomes the running aggregate | |
| op(first.value, second.value); // The second partial reduction does not span a reset, so accumulate both into the running aggregate | |
| return retval; | |
| } | |
| }; | |
| template <typename ReductionOpT> ///< Binary reduction operator to apply to values | |
| struct ReduceByKeyOp | |
| { | |
| /// Wrapped reduction operator | |
| ReductionOpT op; | |
| /// Constructor | |
| __host__ __device__ __forceinline__ ReduceByKeyOp() {} | |
| /// Constructor | |
| __host__ __device__ __forceinline__ ReduceByKeyOp(ReductionOpT op) : op(op) {} | |
| /// Scan operator | |
| template <typename KeyValuePairT> | |
| __host__ __device__ __forceinline__ KeyValuePairT operator()( | |
| const KeyValuePairT &first, ///< First partial reduction | |
| const KeyValuePairT &second) ///< Second partial reduction | |
| { | |
| KeyValuePairT retval = second; | |
| if (first.key == second.key) | |
| retval.value = op(first.value, retval.value); | |
| return retval; | |
| } | |
| }; | |
| /** @} */ // end group UtilModule | |
| } // CUB namespace | |
| CUB_NS_POSTFIX // Optional outer namespace(s) | |