/*!
**************************************************************************************************
* InternImage
* Copyright (c) 2022 OpenGVLab
* Licensed under The MIT License [see LICENSE for details]
**************************************************************************************************
* Modified from
* https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#include "cuda/dcnv3_im2col_cuda.cuh" |
|
#include <vector> |
|
|
|
#include <ATen/ATen.h> |
|
#include <ATen/cuda/CUDAContext.h> |
|
#include <cuda.h> |
|
#include <cuda_runtime.h> |
|
#include <torch/torch.h> |
|
|
|
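// dcnv3_cuda_forward: forward pass of DCNv3 on the GPU.
// Expected (contiguous, channel-last) shapes, as implied by the indexing below:
//   input:  (batch, height_in,  width_in,  group * group_channels)
//   offset: (batch, height_out, width_out, group * kernel_h * kernel_w * 2)
//   mask:   (batch, height_out, width_out, group * kernel_h * kernel_w)
// Returns the output of shape (batch, height_out, width_out, group * group_channels).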
at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset,
                              const at::Tensor &mask, const int kernel_h,
                              const int kernel_w, const int stride_h,
                              const int stride_w, const int pad_h,
                              const int pad_w, const int dilation_h,
                              const int dilation_w, const int group,
                              const int group_channels,
                              const float offset_scale, const int im2col_step) {
    AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous");
    AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous");
    AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous");
    AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
    AT_ASSERTM(offset.is_cuda(), "offset must be a CUDA tensor");
    AT_ASSERTM(mask.is_cuda(), "mask must be a CUDA tensor");

    const int batch = input.size(0);
    const int height_in = input.size(1);
    const int width_in = input.size(2);
    const int channels = input.size(3);
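    // Output resolution follows the standard convolution formula:
    //   out = (in + 2 * pad - (dilation * (kernel - 1) + 1)) / stride + 1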
    const int height_out =
        (height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h +
        1;
    const int width_out =
        (width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w +
        1;
    const int im2col_step_ = std::min(batch, im2col_step);

    AT_ASSERTM(batch % im2col_step_ == 0,
               "batch(%d) must be divisible by im2col_step(%d)", batch,
               im2col_step_);
    AT_ASSERTM(
        channels == (group * group_channels),
        "Input channels and group times group channels do not match: (%d vs %d).",
        channels, group * group_channels);

    auto output =
        at::zeros({batch, height_out, width_out, group * group_channels},
                  input.options());

    const int batch_n = im2col_step_;
    auto output_n = output.view({batch / batch_n, batch_n, height_out,
                                 width_out, group * group_channels});
    auto per_input_size = height_in * width_in * group * group_channels;
    auto per_offset_size =
        height_out * width_out * group * kernel_h * kernel_w * 2;
    auto per_mask_size = height_out * width_out * group * kernel_h * kernel_w;
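    // Process the batch in chunks of im2col_step_ samples. `columns` is a view
    // into `output`, so each dcnv3_im2col_cuda call writes its chunk directly
    // into the output tensor; the per_*_size values above are per-sample
    // element counts used to advance the raw data pointers.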
    for (int n = 0; n < batch / im2col_step_; ++n) {
        auto columns = output_n.select(0, n);
        // AT_DISPATCH_FLOATING_TYPES(
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
            input.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
                dcnv3_im2col_cuda(
                    at::cuda::getCurrentCUDAStream(),
                    input.data_ptr<scalar_t>() +
                        n * im2col_step_ * per_input_size,
                    offset.data_ptr<scalar_t>() +
                        n * im2col_step_ * per_offset_size,
                    mask.data_ptr<scalar_t>() +
                        n * im2col_step_ * per_mask_size,
                    columns.data_ptr<scalar_t>(), kernel_h, kernel_w, stride_h,
                    stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
                    group_channels, batch_n, height_in, width_in, height_out,
                    width_out, offset_scale);
            }));
    }

    return output;
}

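// dcnv3_cuda_backward: backward pass of DCNv3 on the GPU.
// Takes the same (contiguous, channel-last) input/offset/mask tensors as the
// forward pass plus grad_output of shape
// (batch, height_out, width_out, group * group_channels), and returns
// {grad_input, grad_offset, grad_mask} shaped like the corresponding inputs.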
std::vector<at::Tensor>
dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
                    const at::Tensor &mask, const int kernel_h,
                    const int kernel_w, const int stride_h, const int stride_w,
                    const int pad_h, const int pad_w, const int dilation_h,
                    const int dilation_w, const int group,
                    const int group_channels, const float offset_scale,
                    const at::Tensor &grad_output, const int im2col_step) {

    AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous");
    AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous");
    AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous");
    AT_ASSERTM(grad_output.is_contiguous(),
               "grad_output tensor has to be contiguous");
    AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
    AT_ASSERTM(offset.is_cuda(), "offset must be a CUDA tensor");
    AT_ASSERTM(mask.is_cuda(), "mask must be a CUDA tensor");
    AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor");

    const int batch = input.size(0);
    const int height_in = input.size(1);
    const int width_in = input.size(2);
    const int channels = input.size(3);
    const int height_out =
        (height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h +
        1;
    const int width_out =
        (width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w +
        1;
    const int im2col_step_ = std::min(batch, im2col_step);

    AT_ASSERTM(batch % im2col_step_ == 0,
               "batch(%d) must be divisible by im2col_step(%d)", batch,
               im2col_step_);
    AT_ASSERTM(
        channels == (group * group_channels),
        "Input channels and group times group channels do not match: (%d vs %d).",
        channels, group * group_channels);

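    // For half-precision inputs the gradients are accumulated in float32
    // buffers (matching opmath_t in the kernel) and cast back to half at the
    // end, presumably to keep the atomic accumulations at full precision.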
    auto dtype = input.dtype();
    if (dtype == at::kHalf) {
        dtype = at::kFloat;
    }

    auto grad_input = at::zeros_like(input, dtype);
    auto grad_offset = at::zeros_like(offset, dtype);
    auto grad_mask = at::zeros_like(mask, dtype);

    const int batch_n = im2col_step_;
    auto per_input_size = height_in * width_in * group * group_channels;
    auto per_offset_size =
        height_out * width_out * group * kernel_h * kernel_w * 2;
    auto per_mask_size = height_out * width_out * group * kernel_h * kernel_w;
    auto grad_output_n =
        grad_output.view({batch / im2col_step_, batch_n, height_out * width_out,
                          group, group_channels});

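    // Same chunking as the forward pass: dispatch one dcnv3_col2im_cuda call
    // per chunk of im2col_step_ samples, scattering gradients into the
    // grad_input/grad_offset/grad_mask buffers at that chunk's offset.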
    for (int n = 0; n < batch / im2col_step_; ++n) {
        auto grad_output_g = grad_output_n.select(0, n);
        // AT_DISPATCH_FLOATING_TYPES(
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
            input.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
                dcnv3_col2im_cuda(
                    at::cuda::getCurrentCUDAStream(),
                    grad_output_g.data_ptr<scalar_t>(),
                    input.data_ptr<scalar_t>() +
                        n * im2col_step_ * per_input_size,
                    offset.data_ptr<scalar_t>() +
                        n * im2col_step_ * per_offset_size,
                    mask.data_ptr<scalar_t>() +
                        n * im2col_step_ * per_mask_size,
                    kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
                    dilation_h, dilation_w, group, group_channels, batch_n,
                    height_in, width_in, height_out, width_out, offset_scale,
                    grad_input.data_ptr<opmath_t>() +
                        n * im2col_step_ * per_input_size,
                    grad_offset.data_ptr<opmath_t>() +
                        n * im2col_step_ * per_offset_size,
                    grad_mask.data_ptr<opmath_t>() +
                        n * im2col_step_ * per_mask_size);
            }));
    }

    if (input.dtype() == torch::kHalf) {
        return {grad_input.to(torch::kHalf), grad_offset.to(torch::kHalf),
                grad_mask.to(torch::kHalf)};
    } else {
        return {grad_input, grad_offset, grad_mask};
    }
}