// Required headers (assumed; the original listing starts mid-file):
#include <cuda_fp8.h>            // __nv_fp8_e4m3
#include <cooperative_groups.h>
#include <torch/extension.h>

// Constants assumed to be defined elsewhere in the repo:
#define WARPSIZE 32              // threads per warp on NVIDIA GPUs
#define QUANT_MIN_VAL 1e-8f      // padding so the scale never hits zero (assumed value)

namespace cg = cooperative_groups;
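// One CUDA block processes one quantization group of qgroup_size elements
// (block_dim == qgroup_size == 128), so blockIdx.x doubles as the index of
// the group's FP8 scale; each thread updates a single parameter.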
template <typename scalar_t>
__global__ void fp8_adamw_cuda_kernel(
    scalar_t* __restrict__ params, scalar_t* __restrict__ grads,
    __nv_fp8_e4m3* __restrict__ exp_avg, float* __restrict__ scale_exp_avg,
    __nv_fp8_e4m3* __restrict__ exp_avg_sq,
    float* __restrict__ scale_exp_avg_sq, float beta1, float beta2, float lr,
    float wd, float eps, int step, int qgroup_size, int total_elements,
    int total_scale_elements) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int scale_idx = blockIdx.x;
  float float_exp_avg, float_exp_avg_sq;
  float correction1, correction2_sqrt;
  float denom, update;
  if (idx < total_elements) {
    // dequantize the optimizer states
    float_exp_avg = float(exp_avg[idx]) * scale_exp_avg[scale_idx];
    float_exp_avg_sq = float(exp_avg_sq[idx]) * scale_exp_avg_sq[scale_idx];
    // AdamW moment updates
    float_exp_avg = beta1 * float_exp_avg + (1 - beta1) * grads[idx];
    float_exp_avg_sq =
        beta2 * float_exp_avg_sq + (1 - beta2) * grads[idx] * grads[idx];
    // Both bias corrections are folded into the denominator, so
    // float_exp_avg / denom == m_hat / (sqrt(v_hat) + eps) with the usual
    // m_hat = m / (1 - beta1^t) and v_hat = v / (1 - beta2^t).
    correction1 = 1.0f - powf(beta1, step);
    correction2_sqrt = sqrtf(1.0f - powf(beta2, step));
    denom = (sqrtf(float_exp_avg_sq) / correction2_sqrt + eps) * correction1;
    update = (float_exp_avg / denom) + (wd * params[idx]);
    params[idx] = params[idx] - (lr * update);
  } else {
    float_exp_avg = 0.0f;
    float_exp_avg_sq = 0.0f;
  }
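  // Out-of-range threads fall through to the reductions below with zeroed
  // moments; zero is the identity for an absolute-max reduction, so every
  // warp shuffle stays well-defined.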
  // quantize the first- and second-order momenta
  int wid = threadIdx.x / WARPSIZE;
  // reduction within a warp
  __shared__ float sharedFirstMaxVal[32];
  __shared__ float sharedSecondMaxVal[32];
  cg::thread_block_tile<32> warpTile =
      cg::tiled_partition<32>(cg::this_thread_block());
  float firstMaxVal = fabsf(float_exp_avg);
  float secondMaxVal = fabsf(float_exp_avg_sq);
  for (int i = warpTile.size() / 2; i > 0; i /= 2) {
    float reduceFirstMaxVal = warpTile.shfl_down(firstMaxVal, i);
    float reduceSecondMaxVal = warpTile.shfl_down(secondMaxVal, i);
    firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal));
    secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal));
  }
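  // After the shuffle loop, lane 0 of each warp holds that warp's absolute
  // maximum; stash it in shared memory for the block-level pass.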
  int lane = warpTile.thread_rank();
  if (lane == 0) sharedFirstMaxVal[wid] = firstMaxVal;
  if (lane == 0) sharedSecondMaxVal[wid] = secondMaxVal;
  __syncthreads();
  // reduction within a block
  __shared__ float shared_absmax_exp_avg;
  __shared__ float shared_absmax_exp_avg_sq;
  firstMaxVal =
      (threadIdx.x < blockDim.x / WARPSIZE) ? sharedFirstMaxVal[lane] : 0;
  secondMaxVal =
      (threadIdx.x < blockDim.x / WARPSIZE) ? sharedSecondMaxVal[lane] : 0;
  if (wid == 0) {
    for (int offset = WARPSIZE / 2; offset > 0; offset /= 2) {
      float reduceFirstMaxVal =
          __shfl_down_sync(0xFFFFFFFF, firstMaxVal, offset);
      float reduceSecondMaxVal =
          __shfl_down_sync(0xFFFFFFFF, secondMaxVal, offset);
      firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal));
      secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal));
    }
    if (lane == 0) shared_absmax_exp_avg = firstMaxVal;
    if (lane == 0) shared_absmax_exp_avg_sq = secondMaxVal;
  }
  __syncthreads();
  if (idx < total_elements) {
    const float fp8MaxVal = 448.0f;  // largest finite value in FP8 E4M3
    // Read the block-wide absmax into registers, padded with QUANT_MIN_VAL
    // so the quantization scale never becomes zero; keeping the padded
    // value in registers avoids racing on the shared-memory variables.
    float absmax_exp_avg = shared_absmax_exp_avg + QUANT_MIN_VAL;
    float absmax_exp_avg_sq = shared_absmax_exp_avg_sq + QUANT_MIN_VAL;
    float new_scale_exp_avg = absmax_exp_avg / fp8MaxVal;
    float new_scale_exp_avg_sq = absmax_exp_avg_sq / fp8MaxVal;
    // quantize the optimizer states
    __nv_fp8_e4m3 exp_avg_new =
        static_cast<__nv_fp8_e4m3>(float_exp_avg / new_scale_exp_avg);
    __nv_fp8_e4m3 exp_avg_sq_new =
        static_cast<__nv_fp8_e4m3>(float_exp_avg_sq / new_scale_exp_avg_sq);
    // store the quantized moments and their per-group scales
    exp_avg[idx] = exp_avg_new;
    exp_avg_sq[idx] = exp_avg_sq_new;
    scale_exp_avg[scale_idx] = new_scale_exp_avg;
    scale_exp_avg_sq[scale_idx] = new_scale_exp_avg_sq;
  }
}
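// Host-side launcher: checks the per-group quantization layout and
// dispatches the kernel over float, double, half, and bfloat16 parameters.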
void FP8_AdamW_cuda(torch::Tensor params,           // parameters
                    torch::Tensor grads,            // gradients
                    torch::Tensor exp_avg,          // first-order momentum
                    torch::Tensor scale_exp_avg,    // per-group scales for exp_avg
                    torch::Tensor exp_avg_sq,       // second-order momentum
                    torch::Tensor scale_exp_avg_sq, // per-group scales for exp_avg_sq
                    float beta1, float beta2,
                    float lr, float wd, float eps, int step,
                    int qgroup_size) {
  // CUDA blocks
  int total_elements = params.numel();
  int total_scale_elements = scale_exp_avg.numel();
  AT_ASSERTM(qgroup_size == 128,
             "Only per-group quantization with group size 128 is currently supported");
  const int block_dim = 128;  // must equal qgroup_size: one block per group
  int grid_dim = (total_elements + block_dim - 1) / block_dim;
  AT_ASSERTM(grid_dim == scale_exp_avg.numel(),
             "expected one exp_avg scale per quantization group");
  AT_ASSERTM(grid_dim == scale_exp_avg_sq.numel(),
             "expected one exp_avg_sq scale per quantization group");
  const dim3 blocks(grid_dim);
  // Execution
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::kBFloat16, at::kHalf, params.scalar_type(), "fp8_adamw", ([&] {
        fp8_adamw_cuda_kernel<scalar_t><<<blocks, block_dim>>>(
            params.data_ptr<scalar_t>(), grads.data_ptr<scalar_t>(),
            (__nv_fp8_e4m3*)exp_avg.data_ptr<at::Float8_e4m3fn>(),
            scale_exp_avg.data_ptr<float>(),
            (__nv_fp8_e4m3*)exp_avg_sq.data_ptr<at::Float8_e4m3fn>(),
            scale_exp_avg_sq.data_ptr<float>(), beta1, beta2, lr, wd, eps, step,
            qgroup_size, total_elements, total_scale_elements);
      }));
}
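// ---------------------------------------------------------------------------
// Usage sketch, not part of the original listing: one plausible way to expose
// FP8_AdamW_cuda to Python as a torch extension. The exported function name
// is hypothetical; TORCH_EXTENSION_NAME is supplied by torch's build system.
// ---------------------------------------------------------------------------
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fp8_adamw", &FP8_AdamW_cuda,
        "One AdamW step with FP8 E4M3 quantized optimizer states (CUDA)");
}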