|
#include <ATen/ATen.h> |
|
#include <ATen/cuda/Exceptions.h> |
|
#include <ATen/cuda/CUDAContext.h> |
|
|
|
__global__ void update_scale_hysteresis_cuda_kernel(float* current_scale, |
|
int* growth_tracker, |
|
int* hysteresis_tracker, |
|
const float* found_inf, |
|
double growth_factor, |
|
double backoff_factor, |
|
int growth_interval, |
|
int hysteresis) |
|
{ |
|
if (*found_inf > 0) { |
|
*hysteresis_tracker -= 1; |
|
|
|
|
|
if (*hysteresis_tracker > 0) { |
|
*growth_tracker = 0; |
|
return; |
|
} |
|
} |
|
|
|
if (*found_inf) { |
|
*current_scale = (*current_scale)*backoff_factor; |
|
*growth_tracker = 0; |
|
} else { |
|
|
|
|
|
auto successful = (*growth_tracker) + 1; |
|
if (successful == growth_interval) { |
|
auto new_scale = static_cast<float>((*current_scale)*growth_factor); |
|
|
|
if (isfinite(new_scale)) { |
|
*current_scale = new_scale; |
|
} |
|
*growth_tracker = 0; |
|
} else { |
|
*growth_tracker = successful; |
|
} |
|
} |
|
|
|
|
|
if (*found_inf <= 0) { |
|
*hysteresis_tracker = hysteresis; |
|
} |
|
} |
|
|
|
at::Tensor update_scale_hysteresis_cuda(at::Tensor current_scale, |
|
at::Tensor growth_tracker, |
|
at::Tensor hysteresis_tracker, |
|
at::Tensor found_inf, |
|
const double growth_factor, |
|
const double backoff_factor, |
|
const int64_t growth_interval, |
|
const int hysteresis) |
|
{ |
|
update_scale_hysteresis_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( |
|
current_scale.mutable_data_ptr<float>(), |
|
growth_tracker.mutable_data_ptr<int>(), |
|
hysteresis_tracker.mutable_data_ptr<int>(), |
|
found_inf.const_data_ptr<float>(), |
|
growth_factor, |
|
backoff_factor, |
|
growth_interval, |
|
hysteresis); |
|
|
|
AT_CUDA_CHECK(cudaGetLastError()); |
|
|
|
return current_scale; |
|
} |
|
|