Add torch2.6 support for ms_deform_attn_cuda (#94)
This commit is contained in:
@@ -15,11 +15,24 @@
|
|||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <torch/version.h>
|
||||||
|
|
||||||
|
// Check PyTorch version and define appropriate macros
|
||||||
|
#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 6
|
||||||
|
// PyTorch 2.x and above
|
||||||
|
#define GET_TENSOR_TYPE(x) x.scalar_type()
|
||||||
|
#define IS_CUDA_TENSOR(x) x.device().is_cuda()
|
||||||
|
#else
|
||||||
|
// PyTorch 1.x
|
||||||
|
#define GET_TENSOR_TYPE(x) x.type()
|
||||||
|
#define IS_CUDA_TENSOR(x) x.type().is_cuda()
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace groundingdino {
|
namespace groundingdino {
|
||||||
|
|
||||||
at::Tensor ms_deform_attn_cuda_forward(
|
at::Tensor ms_deform_attn_cuda_forward(
|
||||||
const at::Tensor &value,
|
const at::Tensor &value,
|
||||||
const at::Tensor &spatial_shapes,
|
const at::Tensor &spatial_shapes,
|
||||||
const at::Tensor &level_start_index,
|
const at::Tensor &level_start_index,
|
||||||
const at::Tensor &sampling_loc,
|
const at::Tensor &sampling_loc,
|
||||||
@@ -32,11 +45,11 @@ at::Tensor ms_deform_attn_cuda_forward(
|
|||||||
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
|
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
|
||||||
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
||||||
|
|
||||||
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(value), "value must be a CUDA tensor");
|
||||||
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(spatial_shapes), "spatial_shapes must be a CUDA tensor");
|
||||||
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(level_start_index), "level_start_index must be a CUDA tensor");
|
||||||
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(sampling_loc), "sampling_loc must be a CUDA tensor");
|
||||||
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(attn_weight), "attn_weight must be a CUDA tensor");
|
||||||
|
|
||||||
const int batch = value.size(0);
|
const int batch = value.size(0);
|
||||||
const int spatial_size = value.size(1);
|
const int spatial_size = value.size(1);
|
||||||
@@ -51,7 +64,7 @@ at::Tensor ms_deform_attn_cuda_forward(
|
|||||||
const int im2col_step_ = std::min(batch, im2col_step);
|
const int im2col_step_ = std::min(batch, im2col_step);
|
||||||
|
|
||||||
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
|
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
|
||||||
|
|
||||||
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
|
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
|
||||||
|
|
||||||
const int batch_n = im2col_step_;
|
const int batch_n = im2col_step_;
|
||||||
@@ -62,7 +75,7 @@ at::Tensor ms_deform_attn_cuda_forward(
|
|||||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||||
{
|
{
|
||||||
auto columns = output_n.select(0, n);
|
auto columns = output_n.select(0, n);
|
||||||
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES(GET_TENSOR_TYPE(value), "ms_deform_attn_forward_cuda", ([&] {
|
||||||
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
||||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||||
spatial_shapes.data<int64_t>(),
|
spatial_shapes.data<int64_t>(),
|
||||||
@@ -82,7 +95,7 @@ at::Tensor ms_deform_attn_cuda_forward(
|
|||||||
|
|
||||||
|
|
||||||
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
||||||
const at::Tensor &value,
|
const at::Tensor &value,
|
||||||
const at::Tensor &spatial_shapes,
|
const at::Tensor &spatial_shapes,
|
||||||
const at::Tensor &level_start_index,
|
const at::Tensor &level_start_index,
|
||||||
const at::Tensor &sampling_loc,
|
const at::Tensor &sampling_loc,
|
||||||
@@ -98,12 +111,12 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
|||||||
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
||||||
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
|
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
|
||||||
|
|
||||||
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(value), "value must be a CUDA tensor");
|
||||||
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(spatial_shapes), "spatial_shapes must be a CUDA tensor");
|
||||||
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(level_start_index), "level_start_index must be a CUDA tensor");
|
||||||
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(sampling_loc), "sampling_loc must be a CUDA tensor");
|
||||||
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(attn_weight), "attn_weight must be a CUDA tensor");
|
||||||
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(grad_output), "grad_output must be a CUDA tensor");
|
||||||
|
|
||||||
const int batch = value.size(0);
|
const int batch = value.size(0);
|
||||||
const int spatial_size = value.size(1);
|
const int spatial_size = value.size(1);
|
||||||
@@ -128,11 +141,11 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
|||||||
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
|
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
|
||||||
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
|
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
|
||||||
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
|
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
|
||||||
|
|
||||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||||
{
|
{
|
||||||
auto grad_output_g = grad_output_n.select(0, n);
|
auto grad_output_g = grad_output_n.select(0, n);
|
||||||
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES(GET_TENSOR_TYPE(value), "ms_deform_attn_backward_cuda", ([&] {
|
||||||
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
||||||
grad_output_g.data<scalar_t>(),
|
grad_output_g.data<scalar_t>(),
|
||||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||||
@@ -153,4 +166,4 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace groundingdino
|
} // namespace groundingdino
|
||||||
|
Reference in New Issue
Block a user