#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

#include "multi_tensor_apply.cuh"
#include "compat.h"

#include <assert.h>
#include <cuda_runtime.h>

#define BLOCK_SIZE 512
#define ILP 4
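
// Applies one fused SGD step to every chunk handed out by multi_tensor_apply.
// Expected tensor-list layout (one entry per tensor):
//   tl.addresses[0] : gradients        (T_grad)
//   tl.addresses[1] : weights          (T_weight)
//   tl.addresses[2] : momentum buffers (T_weight)
//   tl.addresses[3] : fp16 copy of the updated weights, used only when N == 4
// Scalar arguments: wd (weight decay), momentum, dampening, lr (learning rate),
// nesterov, first_run (initializes the momentum buffer from the gradient),
// wd_after_momentum (apply weight decay after the momentum update rather than
// before), and scale (multiplier applied to each incoming gradient).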
template<int N, typename T_grad, typename T_weight>
struct SGDFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<N>& tl,
    float wd,
    float momentum,
    float dampening,
    float lr,
    bool nesterov,
    bool first_run,
    bool wd_after_momentum,
    float scale)
  {
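    // Skip this chunk entirely if an earlier kernel set the global no-op flag.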
    if (*noop_gmem) return;
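
    // Locate the tensor and chunk this block is responsible for.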
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];
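
    // Base pointers for this tensor, advanced to the start of this chunk.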
    T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc];
    grad_in += chunk_idx*chunk_size;

    T_weight* weight_in = (T_weight*)tl.addresses[1][tensor_loc];
    weight_in += chunk_idx*chunk_size;

    T_weight* mom_in = (T_weight*)tl.addresses[2][tensor_loc];
    mom_in += chunk_idx*chunk_size;
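
    // Optional fp16 copy of the updated weights, present only when four lists are passed.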
    at::Half *model_weights_out = nullptr;
    if(N == 4)
    {
      model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
      model_weights_out += chunk_idx*chunk_size;
    }

    // Elements of this tensor remaining from the start of this chunk.
    n -= chunk_idx*chunk_size;
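
    // Stage ILP values per thread in registers; all math is done in fp32.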
    float incoming_grads[ILP];
    float incoming_weights[ILP];
    float incoming_moms[ILP];
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
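      // First pass: load up to ILP elements per thread and convert to fp32.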
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        incoming_grads[ii] = 0;
        incoming_weights[ii] = 0;
        incoming_moms[ii] = 0;
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          incoming_grads[ii] = static_cast<float>(grad_in[i])*scale;
          incoming_weights[ii] = static_cast<float>(weight_in[i]);
          incoming_moms[ii] = static_cast<float>(mom_in[i]);
        }
      }
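
      // Second pass: apply the SGD update and write weights (and momenta) back.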
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          // Weight decay folded into the gradient before the momentum update.
          if(wd != 0.f && !wd_after_momentum)
            incoming_grads[ii] += wd * incoming_weights[ii];

          if(momentum != 0.f)
          {
            if(!first_run)
              incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
            else // On the first run, seed the momentum buffer with the gradient.
              incoming_moms[ii] = incoming_grads[ii];

            if(nesterov)
              incoming_grads[ii] += momentum * incoming_moms[ii];
            else
              incoming_grads[ii] = incoming_moms[ii];
          }

          // Weight decay applied after the momentum update, if requested.
          if(wd != 0.f && wd_after_momentum)
            incoming_grads[ii] += wd * incoming_weights[ii];

          // Take the SGD step.
          weight_in[i] += (-lr * incoming_grads[ii]);

          // Keep the fp16 copy of the weights in sync when four lists were passed.
          if(N == 4)
            model_weights_out[i] = static_cast<at::Half>(weight_in[i]);

          if(momentum != 0.f)
            mom_in[i] = incoming_moms[ii];
        }
      }
    }
  }
};

void multi_tensor_sgd_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float wd,
  float momentum,
  float dampening,
  float lr,
  bool nesterov,
  bool first_run,
  bool wd_after_momentum,
  float scale)
{
  auto num_tensors = tensor_lists.size();
  auto grad_type = tensor_lists[0][0].scalar_type();
  auto weight_type = tensor_lists[1][0].scalar_type();

  if(num_tensors == 4)
    for(size_t i = 0; i < tensor_lists[3].size(); i++)
      TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
        "Additional output tensors should always be fp16.");

  TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(),
    "expected noop flag to be on the same device as tensors");
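
  // Dispatch on (gradient type, weight type, number of tensor lists). Only the
  // four combinations below are handled; anything else raises an error.

  // Case 1: fp16 gradients, fp16 weights, three lists (no separate fp16 output).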
  if(grad_type == at::ScalarType::Half &&
     weight_type == at::ScalarType::Half &&
     num_tensors == 3)
  {
    multi_tensor_apply<3>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      SGDFunctor<3, at::Half, at::Half>(),
      wd,
      momentum,
      dampening,
      lr,
      nesterov,
      first_run,
      wd_after_momentum,
      scale);
  }
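  // Case 2: fp32 gradients, fp32 weights, three lists.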
  else if(grad_type == at::ScalarType::Float &&
          weight_type == at::ScalarType::Float &&
          num_tensors == 3)
  {
    multi_tensor_apply<3>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      SGDFunctor<3, float, float>(),
      wd,
      momentum,
      dampening,
      lr,
      nesterov,
      first_run,
      wd_after_momentum,
      scale);
  }
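  // Case 3: fp16 gradients, fp32 master weights, four lists
  // (also writes an fp16 copy of the updated weights).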
  else if(grad_type == at::ScalarType::Half &&
          weight_type == at::ScalarType::Float &&
          num_tensors == 4)
  {
    multi_tensor_apply<4>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      SGDFunctor<4, at::Half, float>(),
      wd,
      momentum,
      dampening,
      lr,
      nesterov,
      first_run,
      wd_after_momentum,
      scale);
  }
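  // Case 4: fp32 gradients, fp32 master weights, four lists
  // (also writes an fp16 copy of the updated weights).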
  else if(grad_type == at::ScalarType::Float &&
          weight_type == at::ScalarType::Float &&
          num_tensors == 4)
  {
    multi_tensor_apply<4>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      SGDFunctor<4, float, float>(),
      wd,
      momentum,
      dampening,
      lr,
      nesterov,
      first_run,
      wd_after_momentum,
      scale);
  }
  else
  {
    AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
             "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
  }

  AT_CUDA_CHECK(cudaGetLastError());
}
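
// Example host-side call (a sketch only; the tensors and hyperparameters below
// are illustrative and not defined in this file). With fp32 gradients, weights,
// and momentum buffers, pass three lists in the order {grads, weights, momenta}:
//
//   std::vector<std::vector<at::Tensor>> lists = {grads, weights, momenta};
//   multi_tensor_sgd_cuda(
//     chunk_size,              // elements processed per block, e.g. a few tens of thousands
//     noop_flag,               // int tensor on the same device; nonzero skips the update
//     lists,
//     wd, momentum, dampening, lr,
//     /*nesterov=*/false,
//     /*first_run=*/true,      // true on the first step so momenta are initialized
//     /*wd_after_momentum=*/false,
//     /*scale=*/1.0f);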