#include <ATen/ATen.h>

#include "fused_rotary_positional_embedding.h"
#include "type_shim.h"

namespace fused_rope {
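
// Forward RoPE kernel launcher. `input` is (s, b, h, d): sequence length,
// batch, heads, head dim. `freqs` is (s, 1, 1, d2); only the leading d2
// elements of each head are rotated, and the trailing d - d2 elements are
// copied through unchanged. With `transpose_output`, the result is allocated
// in (b, s, h, d) memory but returned as an (s, b, h, d) view, so a later
// transpose back to batch-first is free.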
torch::Tensor fwd_cuda(const torch::Tensor &input, const torch::Tensor &freqs,
                       const bool transpose_output) {
  // input sizes: (s, b, h, d)
  // s: sequence length
  // b: batch size
  // h: head num
  // d: dim of each head
  const int s = input.size(0);
  const int b = input.size(1);
  const int h = input.size(2);
  const int d = input.size(3);
  // input strides (the input may be non-contiguous)
  const int stride_s = input.stride(0);
  const int stride_b = input.stride(1);
  const int stride_h = input.stride(2);
  const int stride_d = input.stride(3);
  // freqs' shape is (s, 1, 1, d2), so the strides are the same under
  // different memory formats
  const int d2 = freqs.size(3);

  // output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor output;
  if (transpose_output) {
    output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
  } else {
    output = torch::empty({s, b, h, d}, act_options);
  }
  // output strides
  const int o_stride_s = output.stride(0);
  const int o_stride_b = output.stride(1);
  const int o_stride_h = output.stride(2);
  const int o_stride_d = output.stride(3);

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      input.scalar_type(), 0, "dispatch_fused_rope_forward",
      dispatch_fused_rope_forward(
          s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
          o_stride_b, o_stride_h, o_stride_d, input.data_ptr<scalar_t_0>(),
          freqs.data_ptr<float>(), output.data_ptr<scalar_t_0>()););
  return output;
}
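
// Backward counterpart of fwd_cuda: applies the inverse rotation to
// `output_grads` to produce gradients w.r.t. the input. Shapes and the
// `transpose_output` convention match the forward pass.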
torch::Tensor bwd_cuda(const torch::Tensor &output_grads,
                       const torch::Tensor &freqs,
                       const bool transpose_output) {
  // output_grads sizes: (s, b, h, d)
  // s: sequence length
  // b: batch size
  // h: head num
  // d: dim of each head
  const int s = output_grads.size(0);
  const int b = output_grads.size(1);
  const int h = output_grads.size(2);
  const int d = output_grads.size(3);
  // output_grads strides
  const int stride_s = output_grads.stride(0);
  const int stride_b = output_grads.stride(1);
  const int stride_h = output_grads.stride(2);
  const int stride_d = output_grads.stride(3);
  // freqs' shape is (s, 1, 1, d2)
  const int d2 = freqs.size(3);

  auto act_options = output_grads.options().requires_grad(false);
  torch::Tensor input_grads;
  if (transpose_output) {
    input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
  } else {
    input_grads = torch::empty({s, b, h, d}, act_options);
  }
  const int o_stride_s = input_grads.stride(0);
  const int o_stride_b = input_grads.stride(1);
  const int o_stride_h = input_grads.stride(2);
  const int o_stride_d = input_grads.stride(3);

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      output_grads.scalar_type(), 0, "dispatch_fused_rope_backward",
      dispatch_fused_rope_backward(
          s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
          o_stride_b, o_stride_h, o_stride_d,
          output_grads.data_ptr<scalar_t_0>(), freqs.data_ptr<float>(),
          input_grads.data_ptr<scalar_t_0>()););
  return input_grads;
}
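
// Dispatches on the (activation, cos/sin cache) scalar type pair. The cache
// may always be float; otherwise it must match the activation type (half
// with half, bfloat16 with bfloat16). Any other pair fails with a runtime
// TORCH_CHECK. The selected types are exposed to the body as scalar_t_0
// (activation) and scalar_t_1 (cache).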
#define DISPATCH_FUSED_ROPE_TYPES(TYPE1, TYPE2, NAME, ...)                   \
  switch (TYPE1) {                                                           \
    case at::ScalarType::Float: {                                            \
      using scalar_t_0 = float;                                              \
      switch (TYPE2) {                                                       \
        case at::ScalarType::Float: {                                        \
          using scalar_t_1 = float;                                          \
          __VA_ARGS__;                                                       \
          break;                                                             \
        }                                                                    \
        default:                                                             \
          TORCH_CHECK(false, #NAME, " not supported for '", toString(TYPE1), \
                      "' with '", toString(TYPE2), "'");                     \
      }                                                                      \
      break;                                                                 \
    }                                                                        \
    case at::ScalarType::Half: {                                             \
      using scalar_t_0 = at::Half;                                           \
      switch (TYPE2) {                                                       \
        case at::ScalarType::Float: {                                        \
          using scalar_t_1 = float;                                          \
          __VA_ARGS__;                                                       \
          break;                                                             \
        }                                                                    \
        case at::ScalarType::Half: {                                         \
          using scalar_t_1 = at::Half;                                       \
          __VA_ARGS__;                                                       \
          break;                                                             \
        }                                                                    \
        default:                                                             \
          TORCH_CHECK(false, #NAME, " not supported for '", toString(TYPE1), \
                      "' with '", toString(TYPE2), "'");                     \
      }                                                                      \
      break;                                                                 \
    }                                                                        \
    case at::ScalarType::BFloat16: {                                         \
      using scalar_t_0 = at::BFloat16;                                       \
      switch (TYPE2) {                                                       \
        case at::ScalarType::Float: {                                        \
          using scalar_t_1 = float;                                          \
          __VA_ARGS__;                                                       \
          break;                                                             \
        }                                                                    \
        case at::ScalarType::BFloat16: {                                     \
          using scalar_t_1 = at::BFloat16;                                   \
          __VA_ARGS__;                                                       \
          break;                                                             \
        }                                                                    \
        default:                                                             \
          TORCH_CHECK(false, #NAME, " not supported for '", toString(TYPE1), \
                      "' with '", toString(TYPE2), "'");                     \
      }                                                                      \
      break;                                                                 \
    }                                                                        \
    default:                                                                 \
      TORCH_CHECK(false, #NAME, " not supported for '", toString(TYPE1),     \
                  "' with '", toString(TYPE2), "'");                         \
  }
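
// Same as fwd_cuda, but takes precomputed cos/sin caches (each shaped
// (s, 1, 1, d2)) instead of raw frequencies, so the kernel skips the
// per-element sin/cos evaluation. Cache dtypes are dispatched jointly with
// the activation dtype via DISPATCH_FUSED_ROPE_TYPES above.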
torch::Tensor fwd_cached_cuda(const torch::Tensor &input,
                              const torch::Tensor &cos,
                              const torch::Tensor &sin,
                              const bool transpose_output) {
  // input sizes: (s, b, h, d)
  // s: sequence length
  // b: batch size
  // h: head num
  // d: dim of each head
  const int s = input.size(0);
  const int b = input.size(1);
  const int h = input.size(2);
  const int d = input.size(3);
  // input strides
  const int stride_s = input.stride(0);
  const int stride_b = input.stride(1);
  const int stride_h = input.stride(2);
  const int stride_d = input.stride(3);
  // cos' and sin's shape is (s, 1, 1, d2)
  const int d2 = cos.size(3);

  // output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor output;
  if (transpose_output) {
    output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
  } else {
    output = torch::empty({s, b, h, d}, act_options);
  }
  // output strides
  const int o_stride_s = output.stride(0);
  const int o_stride_b = output.stride(1);
  const int o_stride_h = output.stride(2);
  const int o_stride_d = output.stride(3);

  DISPATCH_FUSED_ROPE_TYPES(
      input.scalar_type(), cos.scalar_type(),
      "dispatch_fused_rope_cached_forward",
      dispatch_fused_rope_cached_forward(
          s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
          o_stride_b, o_stride_h, o_stride_d, input.data_ptr<scalar_t_0>(),
          cos.data_ptr<scalar_t_1>(), sin.data_ptr<scalar_t_1>(),
          output.data_ptr<scalar_t_0>()););
  return output;
}
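
// Backward counterpart of fwd_cached_cuda; consumes the same cos/sin caches.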
torch::Tensor bwd_cached_cuda(const torch::Tensor &output_grads,
                              const torch::Tensor &cos,
                              const torch::Tensor &sin,
                              const bool transpose_output) {
  // output_grads sizes: (s, b, h, d)
  // s: sequence length
  // b: batch size
  // h: head num
  // d: dim of each head
  const int s = output_grads.size(0);
  const int b = output_grads.size(1);
  const int h = output_grads.size(2);
  const int d = output_grads.size(3);
  // output_grads strides
  const int stride_s = output_grads.stride(0);
  const int stride_b = output_grads.stride(1);
  const int stride_h = output_grads.stride(2);
  const int stride_d = output_grads.stride(3);
  // cos' and sin's shape is (s, 1, 1, d2)
  const int d2 = cos.size(3);

  auto act_options = output_grads.options().requires_grad(false);
  torch::Tensor input_grads;
  if (transpose_output) {
    input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
  } else {
    input_grads = torch::empty({s, b, h, d}, act_options);
  }
  const int o_stride_s = input_grads.stride(0);
  const int o_stride_b = input_grads.stride(1);
  const int o_stride_h = input_grads.stride(2);
  const int o_stride_d = input_grads.stride(3);

  DISPATCH_FUSED_ROPE_TYPES(
      output_grads.scalar_type(), cos.scalar_type(),
      "dispatch_fused_rope_cached_backward",
      dispatch_fused_rope_cached_backward(
          s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
          o_stride_b, o_stride_h, o_stride_d,
          output_grads.data_ptr<scalar_t_0>(), cos.data_ptr<scalar_t_1>(),
          sin.data_ptr<scalar_t_1>(), input_grads.data_ptr<scalar_t_0>()););
  return input_grads;
}
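
// Packed (THD) forward variant for variable-length batches: `input` is
// (t, h, d) with t the total token count across the batch, and `cu_seqlens`
// holds b + 1 cumulative sequence offsets. `freqs` must cover the longest
// sequence, i.e. be shaped (max_s, 1, 1, d2).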
torch::Tensor fwd_thd_cuda(const torch::Tensor &input,
                           const torch::Tensor &cu_seqlens,
                           const torch::Tensor &freqs) {
  // input sizes: (t, h, d)
  // t: cumulative sum of sequence lengths
  // h: head num
  // d: dim of each head
  const int t = input.size(0);
  const int h = input.size(1);
  const int d = input.size(2);
  // input strides
  const int stride_t = input.stride(0);
  const int stride_h = input.stride(1);
  const int stride_d = input.stride(2);
  // batch size: cu_seqlens holds b + 1 offsets
  const int b = cu_seqlens.size(0) - 1;
  // freqs' shape is (max_s, 1, 1, d2)
  const int max_s = freqs.size(0);
  const int d2 = freqs.size(3);

  // output
  auto act_options = input.options().requires_grad(false);
  auto output = torch::empty({t, h, d}, act_options);
  // output strides
  const int o_stride_t = output.stride(0);
  const int o_stride_h = output.stride(1);
  const int o_stride_d = output.stride(2);

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      input.scalar_type(), 0, "dispatch_fused_rope_thd_forward",
      dispatch_fused_rope_thd_forward(
          max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t,
          o_stride_h, o_stride_d, input.data_ptr<scalar_t_0>(),
          cu_seqlens.data_ptr<int>(), freqs.data_ptr<float>(),
          output.data_ptr<scalar_t_0>()););
  return output;
}
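
// Backward counterpart of fwd_thd_cuda; same packed layout and cu_seqlens
// convention as the forward pass.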
torch::Tensor bwd_thd_cuda(const torch::Tensor &output_grads,
                           const torch::Tensor &cu_seqlens,
                           const torch::Tensor &freqs) {
  // output_grads sizes: (t, h, d)
  // t: cumulative sum of sequence lengths
  // h: head num
  // d: dim of each head
  const int t = output_grads.size(0);
  const int h = output_grads.size(1);
  const int d = output_grads.size(2);
  // output_grads strides
  const int stride_t = output_grads.stride(0);
  const int stride_h = output_grads.stride(1);
  const int stride_d = output_grads.stride(2);
  // batch size: cu_seqlens holds b + 1 offsets
  const int b = cu_seqlens.size(0) - 1;
  // freqs' shape is (max_s, 1, 1, d2)
  const int max_s = freqs.size(0);
  const int d2 = freqs.size(3);

  auto act_options = output_grads.options().requires_grad(false);
  auto input_grads = torch::empty({t, h, d}, act_options);
  const int o_stride_t = input_grads.stride(0);
  const int o_stride_h = input_grads.stride(1);
  const int o_stride_d = input_grads.stride(2);

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      output_grads.scalar_type(), 0, "dispatch_fused_rope_thd_backward",
      dispatch_fused_rope_thd_backward(
          max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t,
          o_stride_h, o_stride_d, output_grads.data_ptr<scalar_t_0>(),
          cu_seqlens.data_ptr<int>(), freqs.data_ptr<float>(),
          input_grads.data_ptr<scalar_t_0>()););
  return input_grads;
}
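
// These launchers are typically exposed to Python by a separate binding
// translation unit. A minimal sketch, assuming pybind11; the module name and
// the Python-facing names below are hypothetical, not part of this file:
//
//   #include <torch/extension.h>
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("forward", &fused_rope::fwd_cuda, "fused RoPE forward (CUDA)");
//     m.def("backward", &fused_rope::bwd_cuda, "fused RoPE backward (CUDA)");
//     m.def("forward_cached", &fused_rope::fwd_cached_cuda,
//           "fused RoPE forward with cached cos/sin (CUDA)");
//     m.def("backward_cached", &fused_rope::bwd_cached_cuda,
//           "fused RoPE backward with cached cos/sin (CUDA)");
//     m.def("forward_thd", &fused_rope::fwd_thd_cuda,
//           "fused RoPE forward for packed THD input (CUDA)");
//     m.def("backward_thd", &fused_rope::bwd_thd_cuda,
//           "fused RoPE backward for packed THD input (CUDA)");
//   }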

}  // namespace fused_rope