| | #pragma once |
| |
|
| | #include "vectorization.cuh" |
| |
|
| | #include <c10/core/ScalarType.h> |
| | #include <cmath> |
| |
|
// Platform selection of the 8-bit floating-point storage type (FP8_TYPE) and
// of the magnitude limit (FP8_E4M3_MAX) that values are clamped to before
// being converted to that type.
#ifdef USE_ROCM
  #include "amd/hip_float8.h"
  #include <c10/util/Float8_e4m3fnuz.h>
using FP8_TYPE = c10::Float8_e4m3fnuz;
// NOTE(review): this limit is hard-coded instead of being taken from
// std::numeric_limits<FP8_TYPE>::max() as the non-ROCm branch does.
// Presumably deliberate -- confirm before replacing it with the type's
// numeric limit.
constexpr auto FP8_E4M3_MAX = 224.0f;
#else
  #include <c10/util/Float8_e4m3fn.h>
using FP8_TYPE = c10::Float8_e4m3fn;
// Largest finite value of the FP8 type; note that the constant itself has
// type FP8_TYPE here, not float.
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
    std::numeric_limits<FP8_TYPE>::max();
#endif
// c10 scalar-type tag that corresponds to FP8_TYPE.
static constexpr auto kFp8Type = c10::CppTypeToScalarType<FP8_TYPE>::value;
| |
|
| | namespace vllm { |
| |
|
// Atomically updates *addr so that it holds the larger of its current value
// and `value`, and returns the float that was stored at addr beforehand.
//
// The update is done on the raw bit pattern of the float:
//  - For value >= 0 the pattern is compared as a signed int with atomicMax.
//    Non-negative floats order the same way as their signed-int patterns,
//    and any negative float has a negative int pattern, so it is replaced.
//  - For value < 0 the pattern is compared as an unsigned int with
//    atomicMin. Among negative floats the larger value has the smaller
//    unsigned pattern, and any non-negative float already has a smaller
//    unsigned pattern than every negative one, so it is kept.
__device__ __forceinline__ float atomicMaxFloat(float *addr, float value) {
  if (value >= 0) {
    int const previous_bits =
        atomicMax(reinterpret_cast<int *>(addr), __float_as_int(value));
    return __int_as_float(previous_bits);
  }

  unsigned int const previous_bits = atomicMin(
      reinterpret_cast<unsigned int *>(addr), __float_as_uint(value));
  return __uint_as_float(previous_bits);
}
| |
|
// Scales a single float and converts it to FP8_TYPE.
//
// Template parameter:
//   is_scale_inverted - when true the value is multiplied by `scale`,
//                       otherwise it is divided by `scale`.
// Parameters:
//   val   - value to convert.
//   scale - scale factor, applied as selected by is_scale_inverted.
// Returns:
//   The scaled value, clamped to [-FP8_E4M3_MAX, FP8_E4M3_MAX] and converted
//   to FP8_TYPE.
template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
                                                          float const scale) {
  // is_scale_inverted is a compile-time constant, so only one arm survives.
  float const scaled = is_scale_inverted ? (val * scale) : (val / scale);

  // Clamp from above first, then from below, so the result is representable.
  float const upper_bounded = fmin(scaled, FP8_E4M3_MAX);
  float const r = fmax(-FP8_E4M3_MAX, upper_bounded);

#ifdef USE_ROCM
  // Build a hip_fp8 from the clamped value and wrap its raw `data` bits in
  // the c10 type without any further numeric conversion.
  return c10::Float8_e4m3fnuz(hip_fp8(r).data,
                              c10::Float8_e4m3fnuz::from_bits());
#else
  return static_cast<c10::Float8_e4m3fn>(r);
#endif
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
// Computes max(|input[i]|) / FP8_E4M3_MAX over the elements visited by this
// launch and folds the per-block results into *scale with atomicMaxFloat.
//
// Algorithm:
//   1. Every thread walks `input` starting at its global thread index with a
//      stride of blockDim.x * gridDim.x and keeps the largest absolute value
//      it sees.
//   2. The per-thread maxima are written to shared memory and reduced
//      pairwise down to cache[0].
//   3. Thread 0 of each block publishes cache[0] / FP8_E4M3_MAX through
//      atomicMaxFloat.
//
// Parameters:
//   scale     - single float that receives the result. The kernel only ever
//               raises it via atomicMaxFloat and never resets it, so the
//               value it holds at launch participates in the result.
//   input     - elements to scan; read-only.
//   num_elems - number of elements in `input`.
//
// Launch requirements:
//   - Indexing uses only the .x components of the launch dimensions.
//   - blockDim.x must not exceed 1024, the size of the shared buffer.
//     blockDim.x does not need to be a power of two.
template <typename scalar_t>
__global__ void segmented_max_reduction(float *__restrict__ scale,
                                        const scalar_t *__restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];

  // blockDim / blockIdx / gridDim are 32-bit unsigned. Widen before
  // multiplying so the products cannot wrap for very large launches.
  int64_t const stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t i = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;

  // Per-thread maximum of |x|, accumulated in float regardless of scalar_t.
  float thread_max = 0.0f;
  for (; i < num_elems; i += stride) {
    thread_max = fmaxf(thread_max, fabsf(static_cast<float>(input[i])));
  }
  cache[threadIdx.x] = thread_max;

  __syncthreads();

  // Pairwise reduction in shared memory. `active` is the number of live
  // entries; each round folds the upper part onto the lower part and keeps
  // ceil(active / 2) entries, so an odd count leaves the middle entry in
  // place instead of dropping it. `active` is uniform across the block,
  // hence every thread reaches each __syncthreads().
  for (unsigned int active = blockDim.x; active > 1;) {
    unsigned int const keep = (active + 1) / 2;
    if (threadIdx.x + keep < active &&
        cache[threadIdx.x + keep] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + keep];
    }
    __syncthreads();
    active = keep;
  }

  // cache[0] now holds the block-wide maximum; one atomic per block.
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
  }
}
| |
|
// Returns the largest absolute value among the elements of `input` that are
// assigned to the calling thread.
//
// The first (num_elems >> 2) * 4 elements are read four at a time through
// vec4_t<scalar_t>; the remaining tail (fewer than four elements) is read one
// element at a time. In both phases the thread starts at index `tid` and
// advances by `step`.
//
// Parameters:
//   input     - elements to scan. It is reinterpreted as an array of
//               vec4_t<scalar_t>, so it must be suitably aligned for that
//               type (caller's responsibility).
//   num_elems - number of scalar elements in `input`.
//   tid       - starting index of the calling thread.
//   step      - index increment between consecutive iterations.
// Returns:
//   max(|x|) over the visited elements, or 0.0f if none were visited.
template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const *__restrict__ input,
                                int64_t const num_elems, int const tid,
                                int const step) {
  using vec_t = vec4_t<scalar_t>;

  vec_t const *const packed = reinterpret_cast<vec_t const *>(input);
  int64_t const packed_count = num_elems >> 2;

  float running_max = 0.0f;

  // Vectorized phase: four elements per load.
#pragma unroll 4
  for (int64_t v = tid; v < packed_count; v += step) {
    vec_t const quad = packed[v];
    running_max = max(running_max, fabs(quad.x));
    running_max = max(running_max, fabs(quad.y));
    running_max = max(running_max, fabs(quad.z));
    running_max = max(running_max, fabs(quad.w));
  }

  // Scalar phase: elements that did not fill a complete group of four.
  int64_t rest = packed_count * 4 + tid;
  while (rest < num_elems) {
    running_max = max(running_max, fabs(input[rest]));
    rest += step;
  }

  return running_max;
}
| |
|
// Scales and converts the elements of `input` that are assigned to the
// calling thread to FP8_TYPE and stores them at the same indices in `out`.
//
// The first (num_elems >> 2) * 4 elements are processed four at a time:
// one vec4_t<scalar_t> load, four conversions, one q8x4_t<FP8_TYPE> store.
// The remaining tail (fewer than four elements) is processed element by
// element. In both phases the thread starts at index `tid` and advances by
// `step`.
//
// Template parameters:
//   scalar_t          - element type of `input`.
//   is_scale_inverted - forwarded to scaled_fp8_conversion; selects whether
//                       `scale` is multiplied or divided.
// Parameters:
//   out       - destination buffer. It is reinterpreted as an array of
//               q8x4_t<FP8_TYPE>, so it must be suitably aligned for that
//               type (caller's responsibility).
//   input     - source buffer. It is reinterpreted as an array of
//               vec4_t<scalar_t>, with the same alignment caveat.
//   scale     - scale factor passed to scaled_fp8_conversion.
//   num_elems - number of scalar elements to convert.
//   tid       - starting index of the calling thread.
//   step      - index increment between consecutive iterations.
template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(FP8_TYPE *__restrict__ out,
                                          scalar_t const *__restrict__ input,
                                          float const scale,
                                          int64_t const num_elems,
                                          int const tid, int const step) {
  using in_quad_t = vec4_t<scalar_t>;
  using float8x4_t = q8x4_t<FP8_TYPE>;

  // Single-element conversion shared by the vectorized and the tail phase.
  auto const quantize = [scale](scalar_t const v) {
    return scaled_fp8_conversion<is_scale_inverted>(static_cast<float>(v),
                                                    scale);
  };

  in_quad_t const *const packed_in =
      reinterpret_cast<in_quad_t const *>(input);
  float8x4_t *const packed_out = reinterpret_cast<float8x4_t *>(out);

  int64_t const packed_count = num_elems >> 2;

  // Vectorized phase: four elements per load/store.
#pragma unroll 4
  for (int64_t v = tid; v < packed_count; v += step) {
    in_quad_t const src = packed_in[v];

    float8x4_t dst;
    dst.x = quantize(src.x);
    dst.y = quantize(src.y);
    dst.z = quantize(src.z);
    dst.w = quantize(src.w);

    packed_out[v] = dst;
  }

  // Scalar phase: elements that did not fill a complete group of four.
  int64_t rest = packed_count * 4 + tid;
  while (rest < num_elems) {
    out[rest] = quantize(input[rest]);
    rest += step;
  }
}
| |
|
| | } |
| |
|