#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
#include <torch/extension.h>
#define QUANT_MIN_VAL 1e-20f
namespace cg = cooperative_groups;
#define WARPSIZE 32
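// One CUDA block processes one quantization group of qgroup_size (= 128)
// elements; each thread handles a single element. Per step the kernel
// (1) dequantizes the FP8 E4M3 momenta with their per-group scales,
// (2) performs the bias-corrected AdamW update in fp32,
// (3) reduces the per-group absolute maxima of both momenta, and
// (4) requantizes the momenta with freshly derived per-group scales.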
template <typename scalar_t>
__global__ void fp8_adamw_cuda_kernel(
    scalar_t* __restrict__ params, scalar_t* __restrict__ grads,
    __nv_fp8_e4m3* __restrict__ exp_avg, float* __restrict__ scale_exp_avg,
    __nv_fp8_e4m3* __restrict__ exp_avg_sq,
    float* __restrict__ scale_exp_avg_sq, float beta1, float beta2, float lr,
    float wd, float eps, int step, int qgroup_size, int total_elements,
    int total_scale_elements) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int scale_idx = blockIdx.x;  // one quantization group (= one scale) per block
  float float_exp_avg, float_exp_avg_sq;
  float correction1, correction2_sqrt;
  float denom, update;
  if (idx < total_elements) {
    // dequantize the optimizer states
    float_exp_avg = float(exp_avg[idx]) * scale_exp_avg[scale_idx];
    float_exp_avg_sq = float(exp_avg_sq[idx]) * scale_exp_avg_sq[scale_idx];
    // optimizer.step(): update the momenta, then apply the bias-corrected
    // AdamW rule p -= lr * (m_hat / (sqrt(v_hat) + eps) + wd * p), with
    // m_hat = m / (1 - beta1^t) and v_hat = v / (1 - beta2^t); folding both
    // corrections into denom below is algebraically the same rule.
    float_exp_avg = beta1 * float_exp_avg + (1.0f - beta1) * grads[idx];
    float_exp_avg_sq =
        beta2 * float_exp_avg_sq + (1.0f - beta2) * grads[idx] * grads[idx];
    correction1 = 1.0f - powf(beta1, step);
    correction2_sqrt = sqrtf(1.0f - powf(beta2, step));
    denom = (sqrtf(float_exp_avg_sq) / correction2_sqrt + eps) * correction1;
    update = (float_exp_avg / denom) + (wd * params[idx]);
    params[idx] = params[idx] - (lr * update);
  } else {
    // out-of-range threads carry zeros so they cannot win the absmax
    // reduction below
    float_exp_avg = 0.0f;
    float_exp_avg_sq = 0.0f;
  }
  //// quantize the first-order and second-order momentum
  int wid = threadIdx.x / WARPSIZE;  // warp id within the block
  // step 1: absmax tree reduction within each warp
  __shared__ float sharedFirstMaxVal[32];
  __shared__ float sharedSecondMaxVal[32];
  cg::thread_block_tile<32> warpTile =
      cg::tiled_partition<32>(cg::this_thread_block());
  float firstMaxVal = fabsf(float_exp_avg);
  float secondMaxVal = fabsf(float_exp_avg_sq);
  for (int i = warpTile.size() / 2; i > 0; i /= 2) {
    float reduceFirstMaxVal = warpTile.shfl_down(firstMaxVal, i);
    float reduceSecondMaxVal = warpTile.shfl_down(secondMaxVal, i);
    firstMaxVal = fmaxf(firstMaxVal, fabsf(reduceFirstMaxVal));
    secondMaxVal = fmaxf(secondMaxVal, fabsf(reduceSecondMaxVal));
  }
  int lane = warpTile.thread_rank();
  if (lane == 0) sharedFirstMaxVal[wid] = firstMaxVal;
  if (lane == 0) sharedSecondMaxVal[wid] = secondMaxVal;
  __syncthreads();
  // step 2: reduction across the per-warp maxima, done by the first warp
  __shared__ float shared_absmax_exp_avg;
  __shared__ float shared_absmax_exp_avg_sq;
  firstMaxVal =
      (threadIdx.x < blockDim.x / warpSize) ? sharedFirstMaxVal[lane] : 0.0f;
  secondMaxVal =
      (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMaxVal[lane] : 0.0f;
  if (wid == 0) {
    for (int offset = WARPSIZE / 2; offset > 0; offset /= 2) {
      float reduceFirstMaxVal =
          __shfl_down_sync(0xFFFFFFFF, firstMaxVal, offset);
      float reduceSecondMaxVal =
          __shfl_down_sync(0xFFFFFFFF, secondMaxVal, offset);
      firstMaxVal = fmaxf(firstMaxVal, fabsf(reduceFirstMaxVal));
      secondMaxVal = fmaxf(secondMaxVal, fabsf(reduceSecondMaxVal));
    }
    if (lane == 0) shared_absmax_exp_avg = firstMaxVal;
    if (lane == 0) shared_absmax_exp_avg_sq = secondMaxVal;
  }
  __syncthreads();
  if (idx < total_elements) {
    const float fp8MaxVal = 448.0f;  // largest finite magnitude in FP8 E4M3
    // Derive the new per-group scales from the block absmax. QUANT_MIN_VAL
    // keeps a scale strictly positive for an all-zero group. The absmax is
    // read into registers rather than incremented in shared memory, which
    // would be a data race with every thread adding the epsilon.
    float new_scale_exp_avg =
        (shared_absmax_exp_avg + QUANT_MIN_VAL) / fp8MaxVal;
    float new_scale_exp_avg_sq =
        (shared_absmax_exp_avg_sq + QUANT_MIN_VAL) / fp8MaxVal;
    // quantize the optimizer states
    __nv_fp8_e4m3 exp_avg_new =
        static_cast<__nv_fp8_e4m3>(float_exp_avg / new_scale_exp_avg);
    __nv_fp8_e4m3 exp_avg_sq_new =
        static_cast<__nv_fp8_e4m3>(float_exp_avg_sq / new_scale_exp_avg_sq);
    // store the quantized momenta; one thread per block writes the scales
    exp_avg[idx] = exp_avg_new;
    exp_avg_sq[idx] = exp_avg_sq_new;
    if (threadIdx.x == 0) {
      scale_exp_avg[scale_idx] = new_scale_exp_avg;
      scale_exp_avg_sq[scale_idx] = new_scale_exp_avg_sq;
    }
  }
}
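// Worked example of the per-group scaling above (illustrative numbers, not
// from the source): if a group's absmax of exp_avg is 3.5, the new scale is
// 3.5 / 448 = 0.0078125; an entry of value 1.75 is stored as the FP8 E4M3
// number 224 and dequantizes back to 224 * 0.0078125 = 1.75 exactly. Values
// that do not land on a representable E4M3 mantissa are rounded, which is
// where this optimizer's quantization error comes from.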
void FP8_AdamW_cuda(torch::Tensor params,         // parameter
                    torch::Tensor grads,          // gradient
                    torch::Tensor exp_avg,        // first-order momentum
                    torch::Tensor scale_exp_avg,
                    torch::Tensor exp_avg_sq,     // second-order momentum
                    torch::Tensor scale_exp_avg_sq, float beta1, float beta2,
                    float lr, float wd, float eps, int step,
                    int qgroup_size) {  // other parameters
  // CUDA launch configuration: one 128-thread block per quantization group
  int total_elements = params.numel();
  int total_scale_elements = scale_exp_avg.numel();
  AT_ASSERTM(qgroup_size == 128,
             "Only per-group quantization with a group size of 128 is "
             "currently supported");
  const int block_dim = 128;  // must equal qgroup_size
  int grid_dim = (total_elements + block_dim - 1) / block_dim;
  AT_ASSERTM(grid_dim == scale_exp_avg.numel(),
             "scale_exp_avg must hold one scale per quantization group");
  AT_ASSERTM(grid_dim == scale_exp_avg_sq.numel(),
             "scale_exp_avg_sq must hold one scale per quantization group");
  const dim3 blocks(grid_dim);
  // Execution
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::kBFloat16, at::kHalf, params.scalar_type(), "fp8_adamw", ([&] {
        fp8_adamw_cuda_kernel<scalar_t><<<blocks, block_dim>>>(
            params.data_ptr<scalar_t>(), grads.data_ptr<scalar_t>(),
            (__nv_fp8_e4m3*)exp_avg.data_ptr<at::Float8_e4m3fn>(),
            scale_exp_avg.data_ptr<float>(),
            (__nv_fp8_e4m3*)exp_avg_sq.data_ptr<at::Float8_e4m3fn>(),
            scale_exp_avg_sq.data_ptr<float>(), beta1, beta2, lr, wd, eps,
            step, qgroup_size, total_elements, total_scale_elements);
      }));
}
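// ---------------------------------------------------------------------------
// Minimal binding sketch, assuming this file is compiled on its own as a
// torch extension. The exported name "fp8_adamw" is an assumption, not the
// source's confirmed API; if another translation unit in the original repo
// already registers the module, this block would be redundant.
// ---------------------------------------------------------------------------
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fp8_adamw", &FP8_AdamW_cuda,
        "AdamW step with FP8 E4M3 per-group quantized momenta (CUDA)");
}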