/* coding=utf-8
 * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <torch/extension.h>

#include "scaled_masked_softmax.h"  // kernel dispatchers (upstream header name)
#include "type_shim.h"              // DISPATCH_HALF_AND_BFLOAT (upstream header name)

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

// Launch heuristic: how many batches a single CUDA block processes for the
// given problem size.
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads) {
  return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}
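
// For reference, a minimal sketch of what the DISPATCH_HALF_AND_BFLOAT macro
// used below conventionally looks like: it switches on the runtime dtype and
// instantiates its body with `scalar_t` bound to the matching C++ type. The
// exact expansion is an assumption (the macro lives in type_shim.h upstream):
//
//   #define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...)
//     switch (TYPE) {
//       case at::ScalarType::Half: {
//         using scalar_t = at::Half;
//         __VA_ARGS__;
//         break;
//       }
//       case at::ScalarType::BFloat16: {
//         using scalar_t = at::BFloat16;
//         __VA_ARGS__;
//         break;
//       }
//       default:
//         TORCH_CHECK(false, #NAME, " not implemented for '", toString(TYPE), "'");
//     }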
torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor)
{
  // input is a 4D tensor with dimensions [batches, attn_heads, query_seq_len, key_seq_len]
  const int batches = input.size(0);
  const int pad_batches = mask.size(0);
  const int attn_heads = input.size(1);
  const int query_seq_len = input.size(2);
  const int key_seq_len = input.size(3);

  TORCH_INTERNAL_ASSERT(key_seq_len <= 8192);
  TORCH_INTERNAL_ASSERT(query_seq_len > 1);
  // The mask either covers every batch or is broadcast from a single batch.
  TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
  TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
  TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
  TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);

  // Output tensor holding the softmax probabilities.
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

  // Raw pointers handed to the kernel dispatch below.
  void* input_ptr = static_cast<void*>(input.data_ptr());
  void* mask_ptr = static_cast<void*>(mask.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(),
      "dispatch_scaled_masked_softmax_forward",
      dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          reinterpret_cast<const uint8_t*>(mask_ptr),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads,
          pad_batches);
  );

  return softmax_results;
}
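
// Usage sketch for fwd_cuda, assuming this file is built as a PyTorch CUDA
// extension and the shape/dtype requirements asserted above hold (in the
// upstream kernels a mask value of 1 marks a position to mask out; the
// concrete sizes below are illustrative, not from the original source):
//
//   auto opts   = torch::dtype(torch::kHalf).device(torch::kCUDA);
//   auto logits = torch::randn({8, 16, 128, 128}, opts);  // [b, h, q, k]
//   auto mask   = torch::zeros({8, 1, 128, 128},
//                              torch::dtype(torch::kUInt8).device(torch::kCUDA));
//   auto probs  = fwd_cuda(logits, mask, 0.125f);         // scale = 1/sqrt(64)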
torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor)
{
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  // output_grads is a 4D tensor with dimensions [batches, attn_heads, query_seq_len, key_seq_len]
  const int batches = output_grads.size(0);
  const int attn_heads = output_grads.size(1);
  const int query_seq_len = output_grads.size(2);
  const int key_seq_len = output_grads.size(3);

  // Gradient w.r.t. the (scaled, masked) softmax input.
  auto act_options = output_grads.options().requires_grad(false);
  torch::Tensor input_grads =
      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

  void* input_grads_ptr = static_cast<void*>(input_grads.data_ptr());
  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
  // Softmax grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(),
      "dispatch_scaled_masked_softmax_backward",
      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(input_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads);
  );

  return input_grads;
}
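
// For reference, the per-row math the backward dispatch implements: with
// probabilities s = softmax(scale * x) saved from the forward pass and
// incoming gradients g = dL/ds, the input gradient is
//
//   dL/dx_i = scale * s_i * (g_i - sum_j g_j * s_j)
//
// i.e. the standard softmax Jacobian-vector product, followed by the
// chain-rule factor for the scale.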
} // namespace scaled_masked_softmax
} // namespace fused_softmax
} // namespace multihead_attn
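
// Binding sketch: upstream, these launchers are exposed to Python from a
// companion .cpp that adds CUDA/contiguity checks first; binding them
// directly would look roughly like this (names and docstrings are
// assumptions, not from the original source):
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("forward",
//           &multihead_attn::fused_softmax::scaled_masked_softmax::fwd_cuda,
//           "Scaled masked softmax -- forward (CUDA)");
//     m.def("backward",
//           &multihead_attn::fused_softmax::scaled_masked_softmax::bwd_cuda,
//           "Scaled masked softmax -- backward (CUDA)");
//     m.def("get_batch_per_block",
//           &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block_cuda,
//           "Batches processed per CUDA block (launch heuristic)");
//   }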