From 2a1967dae6c7f2718c34dd06168abbe7bf9ac2d3 Mon Sep 17 00:00:00 2001 From: momo609 <963372609@qq.com> Date: Sun, 4 Feb 2024 16:57:15 +0800 Subject: [PATCH 01/26] add ms_deform_attn --- mmcv/ops/csrc/pytorch/npu/ms_deform_attn.cpp | 42 ++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 mmcv/ops/csrc/pytorch/npu/ms_deform_attn.cpp diff --git a/mmcv/ops/csrc/pytorch/npu/ms_deform_attn.cpp b/mmcv/ops/csrc/pytorch/npu/ms_deform_attn.cpp new file mode 100644 index 0000000000..68ad24d0b4 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/ms_deform_attn.cpp @@ -0,0 +1,42 @@ + +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +Tensor ms_deform_attn_impl_backward_npu(const Tensor &value, + const Tensor &spatial_shapes, + const Tensor &level_start_index, + const Tensor &sampling_loc, + const Tensor &attn_weight, + const int im2col_step) { + EXEC_NPU_CMD(aclnnMultiScaleDeformableAttnFunction, + value, spatial_shapes, level_start_index, sampling_loc, attn_weight); +} + +void ms_deform_attn_impl_backward_npu( + const Tensor &value, const Tensor &spatial_shapes, + const Tensor &level_start_index, const Tensor &sampling_loc, + const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value, + Tensor &grad_sampling_loc, Tensor &grad_attn_weight, + const int im2col_step) { + EXEC_NPU_CMD(aclnnMultiScaleDeformableAttnFunctionGrad, + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, + grad_output, grad_value, grad_sampling_loc, grad_attn_weight); +} + +Tensor ms_deform_attn_impl_backward_npu(const Tensor &value, + const Tensor &spatial_shapes, + const Tensor &level_start_index, + const Tensor &sampling_loc, + const Tensor &attn_weight, + const int im2col_step); +REGISTER_NPU_IMPL(ms_deform_attn_impl_backward, ms_deform_attn_impl_backward_npu); + +void ms_deform_attn_impl_backward(const Tensor &value, const Tensor &spatial_shapes, + const Tensor &level_start_index, const Tensor &sampling_loc, + const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value, + Tensor &grad_sampling_loc, Tensor &grad_attn_weight, + const int im2col_step); +REGISTER_NPU_IMPL(ms_deform_attn_impl_backward, + ms_deform_attn_impl_backward_npu); From 846f5b521fdb6d65e6533a975efc7563fca9a4b9 Mon Sep 17 00:00:00 2001 From: momo609 <963372609@qq.com> Date: Mon, 5 Feb 2024 11:29:22 +0800 Subject: [PATCH 02/26] add points_in_box --- mmcv/ops/csrc/pytorch/npu/ms_deform_attn.cpp | 42 ------------------- .../csrc/pytorch/npu/points_in_box_npu.cpp | 18 ++++++++ 2 files changed, 18 insertions(+), 42 deletions(-) delete mode 100644 mmcv/ops/csrc/pytorch/npu/ms_deform_attn.cpp create mode 100644 mmcv/ops/csrc/pytorch/npu/points_in_box_npu.cpp diff --git a/mmcv/ops/csrc/pytorch/npu/ms_deform_attn.cpp b/mmcv/ops/csrc/pytorch/npu/ms_deform_attn.cpp deleted file mode 100644 index 68ad24d0b4..0000000000 --- a/mmcv/ops/csrc/pytorch/npu/ms_deform_attn.cpp +++ /dev/null @@ -1,42 +0,0 @@ - -#include "pytorch_npu_helper.hpp" - -using namespace NPU_NAME_SPACE; -using namespace std; - -Tensor ms_deform_attn_impl_backward_npu(const Tensor &value, - const Tensor &spatial_shapes, - const Tensor &level_start_index, - const Tensor &sampling_loc, - const Tensor &attn_weight, - const int im2col_step) { - EXEC_NPU_CMD(aclnnMultiScaleDeformableAttnFunction, - value, spatial_shapes, level_start_index, sampling_loc, attn_weight); -} - -void ms_deform_attn_impl_backward_npu( - const Tensor &value, const Tensor &spatial_shapes, - const Tensor &level_start_index, const Tensor &sampling_loc, - const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value, - Tensor &grad_sampling_loc, Tensor &grad_attn_weight, - const int im2col_step) { - EXEC_NPU_CMD(aclnnMultiScaleDeformableAttnFunctionGrad, - value, spatial_shapes, level_start_index, sampling_loc, attn_weight, - grad_output, grad_value, grad_sampling_loc, grad_attn_weight); -} - -Tensor ms_deform_attn_impl_backward_npu(const Tensor &value, - const Tensor &spatial_shapes, - const Tensor &level_start_index, - const Tensor &sampling_loc, - const Tensor &attn_weight, - const int im2col_step); -REGISTER_NPU_IMPL(ms_deform_attn_impl_backward, ms_deform_attn_impl_backward_npu); - -void ms_deform_attn_impl_backward(const Tensor &value, const Tensor &spatial_shapes, - const Tensor &level_start_index, const Tensor &sampling_loc, - const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value, - Tensor &grad_sampling_loc, Tensor &grad_attn_weight, - const int im2col_step); -REGISTER_NPU_IMPL(ms_deform_attn_impl_backward, - ms_deform_attn_impl_backward_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/points_in_box_npu.cpp b/mmcv/ops/csrc/pytorch/npu/points_in_box_npu.cpp new file mode 100644 index 0000000000..63f0998c69 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/points_in_box_npu.cpp @@ -0,0 +1,18 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +void points_in_boxes_part_forward_impl_npu(int batch_size, int boxes_num, + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points) { + c10::SmallVector output_size = {pts.size(0), pts.size(1)}; + auto boxes_trans = boxes.transpose(1, 2).contiguous(); + EXEC_NPU_CMD(aclnnPointsInBox, boxes_trans, pts, box_idx_of_points); +} +void points_in_boxes_part_forward_impl(int batch_size, int boxes_num, + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points); +REGISTER_NPU_IMPL(points_in_boxes_part_forward_impl, points_in_boxes_part_forward_impl_npu); From 98393a3cc50be499a40be7e0d991170a49ef3256 Mon Sep 17 00:00:00 2001 From: momo609 <963372609@qq.com> Date: Thu, 29 Feb 2024 17:29:26 +0800 Subject: [PATCH 03/26] fix bug. --- mmcv/ops/csrc/pytorch/npu/points_in_polygons_npu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmcv/ops/csrc/pytorch/npu/points_in_polygons_npu.cpp b/mmcv/ops/csrc/pytorch/npu/points_in_polygons_npu.cpp index f282afeed3..6b8f08635a 100644 --- a/mmcv/ops/csrc/pytorch/npu/points_in_polygons_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/points_in_polygons_npu.cpp @@ -12,7 +12,7 @@ void points_in_polygons_npu(const Tensor points, Tensor polygons, Tensor output, "The batch of polygons tensor must be less than MAX_POLYGONS_BATCH"); at::Tensor trans_polygons = polygons.transpose(0, 1); OpCommand cmd; - at::Tensor new_trans_polygons = NpuUtils::format_contiguous(trans_polygons); + at::Tensor new_trans_polygons = trans_polygons.contiguous(); cmd.Name("PointsInPolygons") .Input(points, (string) "points") .Input(new_trans_polygons, (string) "polygons") From 56cda4609515a20b71a379377e82ca749ac09511 Mon Sep 17 00:00:00 2001 From: huaweiZJX <125643694+huaweiZJX@users.noreply.github.com> Date: Wed, 29 Nov 2023 17:24:28 +0800 Subject: [PATCH 04/26] add multi npu op --- docs/zh_cn/understand_mmcv/ops.md | 4 +- .../pytorch/npu/furthest_point_sample_npu.cpp | 19 +++++++++ .../furthest_point_sampling_with_dist_npu.cpp | 20 ++++++++++ .../ops/csrc/pytorch/npu/nms3d_normal_npu.cpp | 21 ++++++++++ mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp | 26 +++++++++++++ mmcv/ops/furthest_point_sample.py | 8 +++- tests/test_ops/test_furthest_point_sample.py | 39 +++++++++++++------ tests/test_ops/test_iou3d.py | 14 +++++-- tests/test_ops/test_roiaware_pool3d.py | 30 ++++++++------ 9 files changed, 152 insertions(+), 29 deletions(-) create mode 100644 mmcv/ops/csrc/pytorch/npu/furthest_point_sample_npu.cpp create mode 100644 mmcv/ops/csrc/pytorch/npu/furthest_point_sampling_with_dist_npu.cpp create mode 100644 mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp create mode 100644 mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index 7f4d7ea63b..d11fc4aa3d 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -22,8 +22,8 @@ MMCV 提供了检测、分割等任务中常用的算子 | Deformable RoIPool | | √ | √ | | √ | | DiffIoURotated | | √ | √ | | | | DynamicScatter | | √ | √ | | | -| FurthestPointSample | | √ | | | | -| FurthestPointSampleWithDist | | √ | | | | +| FurthestPointSample | | √ | | | √ | +| FurthestPointSampleWithDist | | √ | | | √ | | FusedBiasLeakyrelu | | √ | | | √ | | GatherPoints | | √ | | | √ | | GroupPoints | | √ | | | | diff --git a/mmcv/ops/csrc/pytorch/npu/furthest_point_sample_npu.cpp b/mmcv/ops/csrc/pytorch/npu/furthest_point_sample_npu.cpp new file mode 100644 index 0000000000..512f00b2f8 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/furthest_point_sample_npu.cpp @@ -0,0 +1,19 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +void furthest_point_sampling_forward_npu(Tensor points_tensor, Tensor temp_tensor, Tensor idx_tensor, + int b, int n, int m) { + TORCH_CHECK( + (points_tensor.sizes()[1] >= m), + "the num of sampled points needs to be smaller than total num of points."); + at::Tensor points_xyz = points_tensor.transpose(1, 2).contiguous(); + at::Tensor nearest_dist = temp_tensor.contiguous(); + EXEC_NPU_CMD(aclnnFurthestPointSampling, points_xyz, nearest_dist, m, idx_tensor); +} + +void furthest_point_sampling_forward_impl(Tensor points_tensor, Tensor temp_tensor, Tensor idx_tensor, + int b, int n, int m); + +REGISTER_NPU_IMPL(furthest_point_sampling_forward_impl, furthest_point_sampling_forward_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/furthest_point_sampling_with_dist_npu.cpp b/mmcv/ops/csrc/pytorch/npu/furthest_point_sampling_with_dist_npu.cpp new file mode 100644 index 0000000000..ec05fc6d7a --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/furthest_point_sampling_with_dist_npu.cpp @@ -0,0 +1,20 @@ +#include "pytorch_npu_helper.hpp" +using namespace NPU_NAME_SPACE; +using namespace std; + +void furthest_point_sampling_with_dist_npu(Tensor points_tensor, + Tensor temp_tensor, + Tensor idx_tensor, int b, int n, + int m) { + auto output_size = {b, m}; + at::Tensor result = at::empty(output_size, points_tensor.options().dtype(at::kInt)); + EXEC_NPU_CMD(aclnnFurthestPointSamplingWithDist, points_tensor, temp_tensor, m, result); +} + +void furthest_point_sampling_with_dist_forward_impl(Tensor points_tensor, + Tensor temp_tensor, + Tensor idx_tensor, int b, int n, + int m); + +REGISTER_NPU_IMPL(furthest_point_sampling_with_dist_forward_impl, + furthest_point_sampling_with_dist_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp b/mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp new file mode 100644 index 0000000000..923e9d8c90 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp @@ -0,0 +1,21 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; + +void iou3d_nms3d_normal_forward_npu(const Tensor boxes, Tensor &keep, + Tensor &keep_num, float nms_overlap_thresh) { + int32_t box_num = boxes.size(0); + int32_t data_align = 16; + int32_t mask_num = ((box_num - 1) / data_align + 1) * data_align; + at::Tensor mask = at::empty({ box_num, mask_num }, boxes.options().dtype(at::kShort)); + EXEC_NPU_CMD(aclnnNms3dNormal, boxes, nms_overlap_thresh, mask); + + keep = at::zeros({ box_num }, mask.options()); + keep_num = at::zeros(1, mask.options()); + EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, keep_num); +} + +void iou3d_nms3d_normal_forward_impl(const Tensor boxes, Tensor &keep, + Tensor &keep_num, float nms_overlap_thresh); + +REGISTER_NPU_IMPL(iou3d_nms3d_normal_forward_impl, iou3d_nms3d_normal_forward_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp b/mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp new file mode 100644 index 0000000000..f0196a441d --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp @@ -0,0 +1,26 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +constexpr int32_t BOX_DIM = 7; + +void iou3d_nms3d_forward_npu(const Tensor boxes, Tensor &keep, + Tensor &keep_num, float nms_overlap_thresh) +{ + TORCH_CHECK((boxes.sizes()[1] == BOX_DIM), "Input boxes shape should be (N, 7)"); + int32_t box_num = boxes.size(0); + int32_t data_align = 16; + int32_t mask_num = ((box_num - 1) / data_align + 1) * data_align; + at::Tensor mask = at::empty({box_num, mask_num}, boxes.options().dtype(at::kShort)); + EXEC_NPU_CMD(aclnnNms3d, boxes, nms_overlap_thresh, mask); + + keep = at::zeros({box_num}, mask.options()); + keep_num = at::zeros(1, mask.options()); + EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, keep_num); +} + +void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor &keep, + Tensor &keep_num, float nms_overlap_thresh); + +REGISTER_NPU_IMPL(iou3d_nms3d_forward_impl, iou3d_nms3d_forward_npu); diff --git a/mmcv/ops/furthest_point_sample.py b/mmcv/ops/furthest_point_sample.py index 22b1a3048d..409f0dc320 100644 --- a/mmcv/ops/furthest_point_sample.py +++ b/mmcv/ops/furthest_point_sample.py @@ -27,8 +27,12 @@ def forward(ctx, points_xyz: torch.Tensor, assert points_xyz.is_contiguous() B, N = points_xyz.size()[:2] - output = torch.cuda.IntTensor(B, num_points) - temp = torch.cuda.FloatTensor(B, N).fill_(1e10) + if points_xyz.device.type == 'npu': + output = torch.IntTensor(B, num_points).npu() + temp = torch.FloatTensor(B, N).fill_(1e10).npu() + else: + output = torch.cuda.IntTensor(B, num_points) + temp = torch.cuda.FloatTensor(B, N).fill_(1e10) ext_module.furthest_point_sampling_forward( points_xyz, diff --git a/tests/test_ops/test_furthest_point_sample.py b/tests/test_ops/test_furthest_point_sample.py index 7e61e64a91..d03a12c997 100644 --- a/tests/test_ops/test_furthest_point_sample.py +++ b/tests/test_ops/test_furthest_point_sample.py @@ -3,11 +3,20 @@ import torch from mmcv.ops import furthest_point_sample, furthest_point_sample_with_dist +from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE -@pytest.mark.skipif( - not torch.cuda.is_available(), reason='requires CUDA support') -def test_fps(): +@pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) +]) +def test_fps(device): xyz = torch.tensor([[[-0.2748, 1.0020, -1.1674], [0.1015, 1.3952, -1.2681], [-0.8070, 2.4137, -0.5845], [-1.0001, 2.1982, -0.5859], @@ -15,16 +24,24 @@ def test_fps(): [[-1.0696, 3.0758, -0.1899], [-0.2559, 3.5521, -0.1402], [0.8164, 4.0081, -0.1839], [-1.1000, 3.0213, -0.8205], - [-0.0518, 3.7251, -0.3950]]]).cuda() + [-0.0518, 3.7251, -0.3950]]]).to(device) idx = furthest_point_sample(xyz, 3) - expected_idx = torch.tensor([[0, 2, 4], [0, 2, 1]]).cuda() + expected_idx = torch.tensor([[0, 2, 4], [0, 2, 1]]).to(device) assert torch.all(idx == expected_idx) -@pytest.mark.skipif( - not torch.cuda.is_available(), reason='requires CUDA support') -def test_fps_with_dist(): +@pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) +]) +def test_fps_with_dist(device): xyz = torch.tensor([[[-0.2748, 1.0020, -1.1674], [0.1015, 1.3952, -1.2681], [-0.8070, 2.4137, -0.5845], [-1.0001, 2.1982, -0.5859], @@ -32,9 +49,9 @@ def test_fps_with_dist(): [[-1.0696, 3.0758, -0.1899], [-0.2559, 3.5521, -0.1402], [0.8164, 4.0081, -0.1839], [-1.1000, 3.0213, -0.8205], - [-0.0518, 3.7251, -0.3950]]]).cuda() + [-0.0518, 3.7251, -0.3950]]]).to(device) - expected_idx = torch.tensor([[0, 2, 4], [0, 2, 1]]).cuda() + expected_idx = torch.tensor([[0, 2, 4], [0, 2, 1]]).to(device) xyz_square_dist = ((xyz.unsqueeze(dim=1) - xyz.unsqueeze(dim=2))**2).sum(-1) idx = furthest_point_sample_with_dist(xyz_square_dist, 3) @@ -44,7 +61,7 @@ def test_fps_with_dist(): fps_idx = np.load('tests/data/for_3d_ops/fps_idx.npy') features_for_fps_distance = np.load( 'tests/data/for_3d_ops/features_for_fps_distance.npy') - expected_idx = torch.from_numpy(fps_idx).cuda() + expected_idx = torch.from_numpy(fps_idx).to(device) features_for_fps_distance = torch.from_numpy( features_for_fps_distance).cuda() diff --git a/tests/test_ops/test_iou3d.py b/tests/test_ops/test_iou3d.py index 6bb8c1ccce..27a09eb361 100644 --- a/tests/test_ops/test_iou3d.py +++ b/tests/test_ops/test_iou3d.py @@ -4,7 +4,7 @@ import torch from mmcv.ops import boxes_iou3d, boxes_overlap_bev, nms3d, nms3d_normal -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE @pytest.mark.parametrize('device', [ @@ -77,7 +77,11 @@ def test_boxes_iou3d(device): pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU_AVAILABLE, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) ]) def test_nms3d(device): # test for 5 boxes @@ -116,7 +120,11 @@ def test_nms3d(device): pytest.param( 'cuda', marks=pytest.mark.skipif( - not IS_CUDA_AVAILABLE, reason='requires CUDA support')) + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) ]) def test_nms3d_normal(device): # test for 5 boxes diff --git a/tests/test_ops/test_roiaware_pool3d.py b/tests/test_ops/test_roiaware_pool3d.py index 5391e924db..02d99b2a65 100644 --- a/tests/test_ops/test_roiaware_pool3d.py +++ b/tests/test_ops/test_roiaware_pool3d.py @@ -5,7 +5,7 @@ from mmcv.ops import (RoIAwarePool3d, points_in_boxes_all, points_in_boxes_cpu, points_in_boxes_part) -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE @pytest.mark.parametrize('device', [ @@ -56,38 +56,46 @@ def test_RoIAwarePool3d(device, dtype): torch.tensor(49.750, dtype=dtype).to(device), 1e-3) -@pytest.mark.skipif( - not torch.cuda.is_available(), reason='requires CUDA support') -def test_points_in_boxes_part(): +@pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) +]) +def test_points_in_boxes_part(device): boxes = torch.tensor( [[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.3]], [[-10.0, 23.0, 16.0, 10, 20, 20, 0.5]]], - dtype=torch.float32).cuda( - ) # boxes (b, t, 7) with bottom center in lidar coordinate + dtype=torch.float32).to( + device) # boxes (b, t, 7) with bottom center in lidar coordinate pts = torch.tensor( [[[1, 2, 3.3], [1.2, 2.5, 3.0], [0.8, 2.1, 3.5], [1.6, 2.6, 3.6], [0.8, 1.2, 3.9], [-9.2, 21.0, 18.2], [3.8, 7.9, 6.3], [4.7, 3.5, -12.2]], [[3.8, 7.6, -2], [-10.6, -12.9, -20], [-16, -18, 9], [-21.3, -52, -5], [0, 0, 0], [6, 7, 8], [-2, -3, -4], [6, 4, 9]]], - dtype=torch.float32).cuda() # points (b, m, 3) in lidar coordinate + dtype=torch.float32).to(device) # points (b, m, 3) in lidar coordinate point_indices = points_in_boxes_part(points=pts, boxes=boxes) expected_point_indices = torch.tensor( [[0, 0, 0, 0, 0, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1]], - dtype=torch.int32).cuda() + dtype=torch.int32).to(device) assert point_indices.shape == torch.Size([2, 8]) assert (point_indices == expected_point_indices).all() boxes = torch.tensor([[[0.0, 0.0, 0.0, 1.0, 20.0, 1.0, 0.523598]]], - dtype=torch.float32).cuda() # 30 degrees + dtype=torch.float32).to(device) # 30 degrees pts = torch.tensor( [[[4, 6.928, 0], [6.928, 4, 0], [4, -6.928, 0], [6.928, -4, 0], [-4, 6.928, 0], [-6.928, 4, 0], [-4, -6.928, 0], [-6.928, -4, 0]]], - dtype=torch.float32).cuda() + dtype=torch.float32).to(device) point_indices = points_in_boxes_part(points=pts, boxes=boxes) expected_point_indices = torch.tensor([[-1, -1, 0, -1, 0, -1, -1, -1]], - dtype=torch.int32).cuda() + dtype=torch.int32).to(device) assert (point_indices == expected_point_indices).all() From 14f6bdc06e53dc116bc688bb723ade0f25caba68 Mon Sep 17 00:00:00 2001 From: huaweiZJX <125643694+huaweiZJX@users.noreply.github.com> Date: Wed, 29 Nov 2023 17:24:28 +0800 Subject: [PATCH 05/26] add multi npu op. --- .../pytorch/npu/furthest_point_sample_npu.cpp | 18 ++++++++----- .../furthest_point_sampling_with_dist_npu.cpp | 20 +++++++------- .../csrc/pytorch/npu/ms_deform_attn_npu.cpp | 16 +++++------- .../ops/csrc/pytorch/npu/nms3d_normal_npu.cpp | 14 ++++++---- mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp | 26 +++++++++---------- .../csrc/pytorch/npu/points_in_box_npu.cpp | 15 ++++++----- 6 files changed, 59 insertions(+), 50 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/furthest_point_sample_npu.cpp b/mmcv/ops/csrc/pytorch/npu/furthest_point_sample_npu.cpp index 512f00b2f8..65f8b42228 100644 --- a/mmcv/ops/csrc/pytorch/npu/furthest_point_sample_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/furthest_point_sample_npu.cpp @@ -3,17 +3,21 @@ using namespace NPU_NAME_SPACE; using namespace std; -void furthest_point_sampling_forward_npu(Tensor points_tensor, Tensor temp_tensor, Tensor idx_tensor, - int b, int n, int m) { +void furthest_point_sampling_forward_npu(Tensor points_tensor, + Tensor temp_tensor, Tensor idx_tensor, + int b, int n, int m) { TORCH_CHECK( (points_tensor.sizes()[1] >= m), - "the num of sampled points needs to be smaller than total num of points."); + "the num of sampled points should smaller than total num of points."); at::Tensor points_xyz = points_tensor.transpose(1, 2).contiguous(); at::Tensor nearest_dist = temp_tensor.contiguous(); - EXEC_NPU_CMD(aclnnFurthestPointSampling, points_xyz, nearest_dist, m, idx_tensor); + EXEC_NPU_CMD(aclnnFurthestPointSampling, points_xyz, nearest_dist, m, + idx_tensor); } -void furthest_point_sampling_forward_impl(Tensor points_tensor, Tensor temp_tensor, Tensor idx_tensor, - int b, int n, int m); +void furthest_point_sampling_forward_impl(Tensor points_tensor, + Tensor temp_tensor, Tensor idx_tensor, + int b, int n, int m); -REGISTER_NPU_IMPL(furthest_point_sampling_forward_impl, furthest_point_sampling_forward_npu); +REGISTER_NPU_IMPL(furthest_point_sampling_forward_impl, + furthest_point_sampling_forward_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/furthest_point_sampling_with_dist_npu.cpp b/mmcv/ops/csrc/pytorch/npu/furthest_point_sampling_with_dist_npu.cpp index ec05fc6d7a..364d3bfa9a 100644 --- a/mmcv/ops/csrc/pytorch/npu/furthest_point_sampling_with_dist_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/furthest_point_sampling_with_dist_npu.cpp @@ -3,18 +3,20 @@ using namespace NPU_NAME_SPACE; using namespace std; void furthest_point_sampling_with_dist_npu(Tensor points_tensor, - Tensor temp_tensor, - Tensor idx_tensor, int b, int n, - int m) { - auto output_size = {b, m}; - at::Tensor result = at::empty(output_size, points_tensor.options().dtype(at::kInt)); - EXEC_NPU_CMD(aclnnFurthestPointSamplingWithDist, points_tensor, temp_tensor, m, result); + Tensor temp_tensor, + Tensor idx_tensor, int b, int n, + int m) { + auto output_size = {b, m}; + at::Tensor result = + at::empty(output_size, points_tensor.options().dtype(at::kInt)); + EXEC_NPU_CMD(aclnnFurthestPointSamplingWithDist, points_tensor, temp_tensor, + m, result); } -void furthest_point_sampling_with_dist_forward_impl(Tensor points_tensor, +void furthest_point_sampling_with_dist_forward_impl(Tensor points_tensor, Tensor temp_tensor, - Tensor idx_tensor, int b, int n, - int m); + Tensor idx_tensor, int b, + int n, int m); REGISTER_NPU_IMPL(furthest_point_sampling_with_dist_forward_impl, furthest_point_sampling_with_dist_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp b/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp index da6f291048..7e943ca12f 100644 --- a/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp @@ -80,16 +80,14 @@ void ms_deform_attn_impl_backward( const Tensor &value, const Tensor &spatial_shapes, const Tensor &level_start_index, const Tensor &sampling_loc, const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value, - Tensor &grad_sampling_loc, Tensor &grad_attn_weight, - const int im2col_step); + Tensor &grad_sampling_loc, Tensor &grad_attn_weight, const int im2col_step); -void ms_deform_attn_backward_npu(const Tensor &value, const Tensor &spatial_shapes, - const Tensor &level_start_index, - const Tensor &sampling_loc, - const Tensor &attn_weight, - const Tensor &grad_output, Tensor &grad_value, - Tensor &grad_sampling_loc, - Tensor &grad_attn_weight, const int im2col_step) { +void ms_deform_attn_backward_npu( + const Tensor &value, const Tensor &spatial_shapes, + const Tensor &level_start_index, const Tensor &sampling_loc, + const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value, + Tensor &grad_sampling_loc, Tensor &grad_attn_weight, + const int im2col_step) { check_support(value, attn_weight); at::Tensor value_fp32 = value; at::Tensor spatial_shapes_int32 = spatial_shapes; diff --git a/mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp b/mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp index 923e9d8c90..5d812fe047 100644 --- a/mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp @@ -3,19 +3,23 @@ using namespace NPU_NAME_SPACE; void iou3d_nms3d_normal_forward_npu(const Tensor boxes, Tensor &keep, - Tensor &keep_num, float nms_overlap_thresh) { + Tensor &keep_num, + float nms_overlap_thresh) { int32_t box_num = boxes.size(0); int32_t data_align = 16; int32_t mask_num = ((box_num - 1) / data_align + 1) * data_align; - at::Tensor mask = at::empty({ box_num, mask_num }, boxes.options().dtype(at::kShort)); + at::Tensor mask = + at::empty({box_num, mask_num}, boxes.options().dtype(at::kShort)); EXEC_NPU_CMD(aclnnNms3dNormal, boxes, nms_overlap_thresh, mask); - keep = at::zeros({ box_num }, mask.options()); + keep = at::zeros({box_num}, mask.options()); keep_num = at::zeros(1, mask.options()); EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, keep_num); } void iou3d_nms3d_normal_forward_impl(const Tensor boxes, Tensor &keep, - Tensor &keep_num, float nms_overlap_thresh); + Tensor &keep_num, + float nms_overlap_thresh); -REGISTER_NPU_IMPL(iou3d_nms3d_normal_forward_impl, iou3d_nms3d_normal_forward_npu); +REGISTER_NPU_IMPL(iou3d_nms3d_normal_forward_impl, + iou3d_nms3d_normal_forward_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp b/mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp index f0196a441d..13fe6db860 100644 --- a/mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp @@ -5,19 +5,19 @@ using namespace std; constexpr int32_t BOX_DIM = 7; -void iou3d_nms3d_forward_npu(const Tensor boxes, Tensor &keep, - Tensor &keep_num, float nms_overlap_thresh) -{ - TORCH_CHECK((boxes.sizes()[1] == BOX_DIM), "Input boxes shape should be (N, 7)"); - int32_t box_num = boxes.size(0); - int32_t data_align = 16; - int32_t mask_num = ((box_num - 1) / data_align + 1) * data_align; - at::Tensor mask = at::empty({box_num, mask_num}, boxes.options().dtype(at::kShort)); - EXEC_NPU_CMD(aclnnNms3d, boxes, nms_overlap_thresh, mask); - - keep = at::zeros({box_num}, mask.options()); - keep_num = at::zeros(1, mask.options()); - EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, keep_num); +void iou3d_nms3d_forward_npu(const Tensor boxes, Tensor &keep, Tensor &keep_num, + float nms_overlap_thresh) { + TORCH_CHECK((boxes.sizes()[1] == BOX_DIM), + "Input boxes shape should be (N, 7)"); + int32_t box_num = boxes.size(0); + int32_t data_align = 16; + int32_t mask_num = ((box_num - 1) / data_align + 1) * data_align; + at::Tensor mask = + at::empty({box_num, mask_num}, boxes.options().dtype(at::kShort)); + EXEC_NPU_CMD(aclnnNms3d, boxes, nms_overlap_thresh, mask); + keep = at::zeros({box_num}, mask.options()); + keep_num = at::zeros(1, mask.options()); + EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, keep_num); } void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor &keep, diff --git a/mmcv/ops/csrc/pytorch/npu/points_in_box_npu.cpp b/mmcv/ops/csrc/pytorch/npu/points_in_box_npu.cpp index 63f0998c69..70ccf0f6ae 100644 --- a/mmcv/ops/csrc/pytorch/npu/points_in_box_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/points_in_box_npu.cpp @@ -4,15 +4,16 @@ using namespace NPU_NAME_SPACE; using namespace std; void points_in_boxes_part_forward_impl_npu(int batch_size, int boxes_num, - int pts_num, const Tensor boxes, - const Tensor pts, - Tensor box_idx_of_points) { - c10::SmallVector output_size = {pts.size(0), pts.size(1)}; - auto boxes_trans = boxes.transpose(1, 2).contiguous(); - EXEC_NPU_CMD(aclnnPointsInBox, boxes_trans, pts, box_idx_of_points); + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points) { + c10::SmallVector output_size = {pts.size(0), pts.size(1)}; + auto boxes_trans = boxes.transpose(1, 2).contiguous(); + EXEC_NPU_CMD(aclnnPointsInBox, boxes_trans, pts, box_idx_of_points); } void points_in_boxes_part_forward_impl(int batch_size, int boxes_num, int pts_num, const Tensor boxes, const Tensor pts, Tensor box_idx_of_points); -REGISTER_NPU_IMPL(points_in_boxes_part_forward_impl, points_in_boxes_part_forward_impl_npu); +REGISTER_NPU_IMPL(points_in_boxes_part_forward_impl, + points_in_boxes_part_forward_impl_npu); From bfffaebc75be5138601cd93388f6e78c633f3375 Mon Sep 17 00:00:00 2001 From: momo609 <963372609@qq.com> Date: Mon, 13 May 2024 10:05:24 +0800 Subject: [PATCH 06/26] add constraints of pointinpolygon --- mmcv/ops/points_in_polygons.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mmcv/ops/points_in_polygons.py b/mmcv/ops/points_in_polygons.py index e54b5a896d..25b641d961 100644 --- a/mmcv/ops/points_in_polygons.py +++ b/mmcv/ops/points_in_polygons.py @@ -19,6 +19,9 @@ def points_in_polygons(points: Tensor, polygons: Tensor) -> Tensor: polygons (torch.Tensor): It has shape (M, 8), indicating (x1, y1, x2, y2, x3, y3, x4, y4). M means the number of ground truth polygons. + constraints: The absolute value of input range from e-10 to + e10 on NPU. + Returns: torch.Tensor: Return the result with the shape of (B, M), From abf8ca754b0b0036ec3904748251f8ec4a5d32a7 Mon Sep 17 00:00:00 2001 From: momo609 <963372609@qq.com> Date: Tue, 28 May 2024 15:35:49 +0800 Subject: [PATCH 07/26] fix roi_pool bug. --- mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp | 57 +++++++++++++++------- 1 file changed, 40 insertions(+), 17 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp b/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp index bf0eb18d2b..2b3af2575c 100644 --- a/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp @@ -11,19 +11,37 @@ void roi_pool_forward_npu(Tensor input, Tensor rois, Tensor output, int64_t pooled_channel = 1; at::Tensor roi_actual_num = at::empty_like(rois, rois.options().dtype(at::kInt)); - OpCommand cmd; - cmd.Name("RoiPoolingWithArgMax") - .Input(input) - .Input(rois) - .Input(roi_actual_num) - .Output(output) - .Output(argmax) - .Attr("pooled_h", pooled_height_64) - .Attr("pooled_w", pooled_width_64) - .Attr("spatial_scale_h", spatial_scale) - .Attr("spatial_scale_w", spatial_scale) - .Attr("pool_channel", pooled_channel) - .Run(); + if (input.sizes()[1] % 16 == 0) { + OpCommand cmd; + cmd.Name("RoiPoolingWithArgMax") + .Input(input) + .Input(rois) + .Input(roi_actual_num) + .Output(output) + .Output(argmax) + .Attr("pooled_h", pooled_height_64) + .Attr("pooled_w", pooled_width_64) + .Attr("spatial_scale_h", spatial_scale) + .Attr("spatial_scale_w", spatial_scale) + .Attr("pool_channel", pooled_channel) + .Run(); + + } else { + OpCommand cmd; + cmd.Name("RoiPoolingWithArgMax") + .Input(input) + .Input(rois) + .Input(roi_actual_num) + .Output(output) + .Output(argmax) + .Attr("pooled_h", pooled_height_64) + .Attr("pooled_w", pooled_width_64) + .Attr("spatial_scale_h", spatial_scale) + .Attr("spatial_scale_w", spatial_scale) + .Attr("pool_channel", pooled_channel) + .Attr("_exclude_engines", (string) "AiCore") + .Run(); + } } void roi_pool_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax, @@ -32,23 +50,28 @@ void roi_pool_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax, int64_t pooled_height_64 = pooled_height; int64_t pooled_width_64 = pooled_width; int64_t pooled_channel = 1; + at::Tensor argmax_trans = argmax.transpose(1, 2).transpose(2, 3); + at::Tensor grad_output_trans = grad_output.transpose(1, 2).transpose(2, 3); at::Tensor roi_actual_num = at::empty_like(rois, rois.options().dtype(at::kInt)); - at::Tensor x = at::ones_like(grad_input); + at::Tensor x = at::ones_like(grad_input).transpose(1, 2).transpose(2, 3); + at::Tensor y = at::zeros_like(x); OpCommand cmd; cmd.Name("RoiPoolingGradWithArgMax") - .Input(grad_output) + .Input(grad_output_trans) .Input(x) .Input(rois) .Input(roi_actual_num) - .Input(argmax) - .Output(grad_input) + .Input(argmax_trans) + .Output(y) .Attr("pooled_h", pooled_height_64) .Attr("pooled_w", pooled_width_64) .Attr("spatial_scale_h", spatial_scale) .Attr("spatial_scale_w", spatial_scale) .Attr("pool_channel", pooled_channel) .Run(); + at::Tensor res = NpuUtils::format_contiguous(result); + grad_input.copy_(res); } void roi_pool_forward_impl(Tensor input, Tensor rois, Tensor output, From 737b5b4f08ddeee2a76d40911b1d5dbfcc82f0f2 Mon Sep 17 00:00:00 2001 From: momo609 <963372609@qq.com> Date: Thu, 30 May 2024 19:22:04 +0800 Subject: [PATCH 08/26] fix gather_point bug. --- .../furthest_point_sampling_with_dist_npu.cpp | 8 ++++---- mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp | 16 +++++++++++++--- mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp | 3 ++- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/furthest_point_sampling_with_dist_npu.cpp b/mmcv/ops/csrc/pytorch/npu/furthest_point_sampling_with_dist_npu.cpp index 364d3bfa9a..24317a06bb 100644 --- a/mmcv/ops/csrc/pytorch/npu/furthest_point_sampling_with_dist_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/furthest_point_sampling_with_dist_npu.cpp @@ -6,11 +6,11 @@ void furthest_point_sampling_with_dist_npu(Tensor points_tensor, Tensor temp_tensor, Tensor idx_tensor, int b, int n, int m) { - auto output_size = {b, m}; - at::Tensor result = - at::empty(output_size, points_tensor.options().dtype(at::kInt)); + TORCH_CHECK( + (points_tensor.sizes()[1] >= m), + "the num of sampled points should smaller than total num of points."); EXEC_NPU_CMD(aclnnFurthestPointSamplingWithDist, points_tensor, temp_tensor, - m, result); + m, idx_tensor); } void furthest_point_sampling_with_dist_forward_impl(Tensor points_tensor, diff --git a/mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp b/mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp index cf3a577ce1..991e6038db 100644 --- a/mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp @@ -24,6 +24,12 @@ void gather_points_forward_npu(int b, int c, int n, int npoints, void gather_points_backward_npu(int b, int c, int n, int npoints, const Tensor grad_out, const Tensor idx, Tensor grad_points) { + at::Tensor grad_out_cast = grad_out; + at::Tensor grad_points_cast = grad_points; + if (grad_out.scalar_type() == at::ScalarType::Half) { + grad_out_cast = grad_out.to(at::kFloat); + grad_points_cast = grad_points.to(at::kFloat); + } at::Tensor indices = idx; if (idx.scalar_type() != at::ScalarType::Int) { indices = idx.to(at::kInt); @@ -37,11 +43,11 @@ void gather_points_backward_npu(int b, int c, int n, int npoints, for (uint64_t i = 0; i < shape.size(); i++) { pad_size.emplace_back(shape[i]); } - at::Tensor trans_grad_points = grad_points.transpose(1, 2).contiguous(); + at::Tensor trans_grad_points = grad_points_cast.transpose(1, 2).contiguous(); at::Tensor grad_points_view = trans_grad_points.view( {trans_grad_points.sizes()[0] * trans_grad_points.sizes()[1], trans_grad_points.sizes()[2]}); - at::Tensor trans_grad_out = grad_out.transpose(1, 2).contiguous(); + at::Tensor trans_grad_out = grad_out_cast.transpose(1, 2).contiguous(); trans_grad_out = trans_grad_out.view( {trans_grad_out.sizes()[0] * trans_grad_out.sizes()[1], trans_grad_out.sizes()[2]}); @@ -63,7 +69,11 @@ void gather_points_backward_npu(int b, int c, int n, int npoints, at::Tensor grad_points_result = grad_points_view.view(trans_grad_points.sizes()); grad_points_result = grad_points_result.transpose(1, 2); - grad_points.copy_(grad_points_result); + at::Tensor grad_points_result_cast = grad_points_result; + if (grad_out.scalar_type() == at::ScalarType::Half) { + grad_points_result_cast = grad_points_result.to(at::kHalf); + } + grad_points.copy_(grad_points_result_cast); } void gather_points_forward_impl(int b, int c, int n, int npoints, diff --git a/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp b/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp index 2b3af2575c..b7015439b9 100644 --- a/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp @@ -70,7 +70,8 @@ void roi_pool_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax, .Attr("spatial_scale_w", spatial_scale) .Attr("pool_channel", pooled_channel) .Run(); - at::Tensor res = NpuUtils::format_contiguous(result); + at::Tensor result = y.transpose(2, 3).transpose(1, 2); + at::Tensor res = result.contiguous(); grad_input.copy_(res); } From d90969b7e525b39c236c141ac7d2618785318345 Mon Sep 17 00:00:00 2001 From: ZYF-Annarine Date: Thu, 6 Jun 2024 15:17:45 +0800 Subject: [PATCH 09/26] chamfer_distance fp16->fp32 --- .../csrc/pytorch/npu/chamfer_distance_npu.cpp | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp index f34d4289c4..0f9f099901 100644 --- a/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp @@ -1,4 +1,3 @@ - #include "pytorch_npu_helper.hpp" using namespace NPU_NAME_SPACE; @@ -6,19 +5,34 @@ using namespace std; void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1, Tensor dist2, Tensor idx1, Tensor idx2) { + bool is_half = input.scalar_type() == at::kHalf; at::Tensor xyz1 = at::ones_like(XYZ1); at::Tensor xyz2 = at::ones_like(XYZ2); + at::Tensor distf1 = at::ones_like(dist1); + at::Tensor distf2 = at::ones_like(dist2); xyz1 = XYZ1.transpose(1, 2).transpose(0, 1); xyz2 = XYZ2.transpose(1, 2).transpose(0, 1); + if (is_half) { + xyz1 = xyz1.to(at::kFloat); + xyz2 = xyz2.to(at::kFloat); + distf1 = dist1.to(at::kFloat); + distf2 = dist2.to(at::kFloat); + } OpCommand cmd; cmd.Name("ChamferDistance") .Input(xyz1) .Input(xyz2) - .Output(dist1) - .Output(dist2) + .Output(distf1) + .Output(distf2) .Output(idx1) .Output(idx2) .Run(); + if (is_half) { + distf1 = distf1.to(at::kHalf); + distf2 = distf2.to(at::kHalf); + } + dist1.copy_(distf1); + dist2.copy_(distf2); } void chamfer_distance_backward_npu(Tensor xyz1, Tensor xyz2, Tensor idx1, From e4e8f503b7ba732fc4c914614faef33b2c4bd4bc Mon Sep 17 00:00:00 2001 From: ZYF-Annarine Date: Thu, 6 Jun 2024 15:35:46 +0800 Subject: [PATCH 10/26] chamfer_distance fp16->fp32 --- .../csrc/pytorch/npu/chamfer_distance_npu.cpp | 2 +- mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp | 105 ++++++++++++++---- mmcv/ops/points_in_boxes.py | 5 +- 3 files changed, 89 insertions(+), 23 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp index 0f9f099901..4f5c32dbec 100644 --- a/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp @@ -5,7 +5,7 @@ using namespace std; void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1, Tensor dist2, Tensor idx1, Tensor idx2) { - bool is_half = input.scalar_type() == at::kHalf; + bool is_half = XYZ1.scalar_type() == at::kHalf; at::Tensor xyz1 = at::ones_like(XYZ1); at::Tensor xyz2 = at::ones_like(XYZ2); at::Tensor distf1 = at::ones_like(dist1); diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp index 5030fed0e7..3f3bc5a047 100644 --- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -4,6 +4,21 @@ using namespace std; void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha) { + at::Tensor input_y = input; + at::Tensor output_y = output; + bool is_half = input.scalar_type() == at::kHalf; + if (is_half) { + input_y = input.to(at::kFloat); + output_y = output.to(at::kFloat); + } + int64_t weight_size = weight.size(0); + at::Tensor weight_y = at::ones_like(input_y); + if (weight_size > 0) { + weight_y = at::broadcast_to(weight, input.sizes()); + if (is_half) { + weight_y = weight_y.to(at::kFloat); + } + } int64_t n_class = input.size(1); at::Tensor target_y = at::ones_like(input); if (n_class == 1) { @@ -12,24 +27,26 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, target_y = at::add(target_y, 1.0); } else { target_y = at::one_hot(target, n_class); + weight_y = at::mul(weight_y, target_y); + weight_y = at::sum(weight_y, 1, true); + weight_y = at::broadcast_to(weight_y, input.sizes()); } target_y = target_y.to(at::kInt); - int64_t weight_size = weight.size(0); - at::Tensor weight_y = at::ones_like(input); - if (weight_size > 0) { - weight_y = at::broadcast_to(weight, input.sizes()); - } OpCommand cmd; string reduction = "none"; cmd.Name("SigmoidFocalLoss") - .Input(input) + .Input(input_y) .Input(target_y) .Input(weight_y) - .Output(output) + .Output(output_y) .Attr("gamma", gamma) .Attr("alpha", alpha) .Attr("reduction", reduction) .Run(); + if (is_half) { + output_y = output_y.to(at::kHalf); + } + output.copy_(output_y); } void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, @@ -38,34 +55,51 @@ void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, Tensor grad_input, float gamma, float alpha) { + at::Tensor input_y = input; + at::Tensor grad_input_y = grad_input; + bool is_half = input.scalar_type() == at::kHalf; + if (is_half) { + input_y = input.to(at::kFloat); + grad_input_y = grad_input.to(at::kFloat); + } + int64_t weight_size = weight.size(0); + at::Tensor weight_y = at::ones_like(input_y); + if (weight_size > 0) { + weight_y = at::broadcast_to(weight, input.sizes()); + if (is_half) { + weight_y = weight_y.to(at::kFloat); + } + } int64_t n_class = input.size(1); at::Tensor target_y = at::ones_like(input); if (n_class == 1) { target_y = at::reshape(target, input.sizes()); } else { target_y = at::one_hot(target, n_class); + weight_y = at::mul(weight_y, target_y); + weight_y = at::sum(weight_y, 1, true); + weight_y = at::broadcast_to(weight_y, input.sizes()); target_y = at::mul(target_y, -1.0); target_y = at::add(target_y, 1.0); } target_y = target_y.to(at::kInt); at::Tensor grad_up = at::ones_like(input); - int64_t weight_size = weight.size(0); - at::Tensor weight_y = at::ones_like(input); - if (weight_size > 0) { - weight_y = at::broadcast_to(weight, input.sizes()); - } OpCommand cmd; string reduction = "none"; cmd.Name("SigmoidFocalLossGrad") - .Input(input) + .Input(input_y) .Input(target_y) .Input(grad_up) .Input(weight_y) - .Output(grad_input) + .Output(grad_input_y) .Attr("gamma", gamma) .Attr("alpha", alpha) .Attr("reduction", reduction) .Run(); + if (is_half) { + grad_input_y = grad_input_y.to(at::kHalf); + } + grad_input.copy_(grad_input_y); } void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target, @@ -74,19 +108,30 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target, void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha) { + at::Tensor input_y = input; + bool is_half = input.scalar_type() == at::kHalf; + if (is_half) { + input_y = input.to(at::kFloat); + } int64_t n_class = input.size(1); at::Tensor target_y = at::one_hot(target, n_class); target_y = target_y.to(at::kInt); int64_t weight_size = weight.size(0); - at::Tensor weight_y = at::ones_like(input); + at::Tensor weight_y = at::ones_like(input_y); if (weight_size > 0) { weight_y = at::broadcast_to(weight, input.sizes()); + if (is_half) { + weight_y = weight_y.to(at::kFloat); + } + weight_y = at::mul(weight_y, target_y); + weight_y = at::sum(weight_y, 1, true); + weight_y = at::broadcast_to(weight_y, input.sizes()); } - at::Tensor op_output = at::ones_like(input); + at::Tensor op_output = at::ones_like(input_y); OpCommand cmd; string reduction = "none"; cmd.Name("SoftmaxFocalLoss") - .Input(input) + .Input(input_y) .Input(target_y) .Input(weight_y) .Output(op_output) @@ -94,6 +139,9 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, .Attr("alpha", alpha) .Attr("reduction", reduction) .Run(); + if (is_half) { + op_output = op_output.to(at::kHalf); + } int64_t n_batch = input.size(0); c10::SmallVector offsets = {0, 0}; c10::SmallVector sizes = {n_batch, 1}; @@ -124,27 +172,44 @@ void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, Tensor buff, Tensor grad_input, float gamma, float alpha) { + at::Tensor input_y = input; + at::Tensor grad_input_y = grad_input; + bool is_half = input.scalar_type() == at::kHalf; + if (is_half) { + input_y = input.to(at::kFloat); + grad_input_y = grad_input.to(at::kFloat); + } int64_t n_class = input.size(1); at::Tensor target_y = at::one_hot(target, n_class); target_y = target_y.to(at::kInt); at::Tensor grad_up = at::ones_like(input); int64_t weight_size = weight.size(0); - at::Tensor weight_y = at::ones_like(input); + at::Tensor weight_y = at::ones_like(input_y); if (weight_size > 0) { weight_y = at::broadcast_to(weight, input.sizes()); + if (is_half) { + weight_y = weight_y.to(at::kFloat); + } + weight_y = at::mul(weight_y, target_y); + weight_y = at::sum(weight_y, 1, true); + weight_y = at::broadcast_to(weight_y, input.sizes()); } OpCommand cmd; string reduction = "none"; cmd.Name("SoftmaxFocalLossGrad") - .Input(input) + .Input(input_y) .Input(target_y) .Input(grad_up) .Input(weight_y) - .Output(grad_input) + .Output(grad_input_y) .Attr("gamma", gamma) .Attr("alpha", alpha) .Attr("reduction", reduction) .Run(); + if (is_half) { + grad_input_y = grad_input_y.to(at::kHalf); + } + grad_input.copy_(grad_input_y); } void softmax_focal_loss_backward_impl(Tensor input, Tensor target, diff --git a/mmcv/ops/points_in_boxes.py b/mmcv/ops/points_in_boxes.py index 4915e6b573..94c8dad33e 100644 --- a/mmcv/ops/points_in_boxes.py +++ b/mmcv/ops/points_in_boxes.py @@ -47,8 +47,9 @@ def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor: points_device = points.get_device() assert points_device == boxes.get_device(), \ 'Points and boxes should be put on the same device' - if torch.cuda.current_device() != points_device: - torch.cuda.set_device(points_device) + if points.device.type != 'npu': + if torch.cuda.current_device() != points_device: + torch.cuda.set_device(points_device) ext_module.points_in_boxes_part_forward(boxes.contiguous(), points.contiguous(), From 4d7bf263bd5f4437232b1f6823855b7a4594771a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=83=A1=E5=AE=8F=E7=AC=8B?= Date: Fri, 17 May 2024 11:32:50 +0800 Subject: [PATCH 11/26] update constraints of points_in_polygons --- mmcv/ops/points_in_polygons.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mmcv/ops/points_in_polygons.py b/mmcv/ops/points_in_polygons.py index 25b641d961..8d3bc8dd48 100644 --- a/mmcv/ops/points_in_polygons.py +++ b/mmcv/ops/points_in_polygons.py @@ -19,9 +19,8 @@ def points_in_polygons(points: Tensor, polygons: Tensor) -> Tensor: polygons (torch.Tensor): It has shape (M, 8), indicating (x1, y1, x2, y2, x3, y3, x4, y4). M means the number of ground truth polygons. - constraints: The absolute value of input range from e-10 to - e10 on NPU. - + constraints: The number of significant digits for the input-arguments + are between -10 and 10 when running on Ascend device. Returns: torch.Tensor: Return the result with the shape of (B, M), From c38593d25ec5bab75c4b4c8ad8ca4ac83719ac1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=83=A1=E5=AE=8F=E7=AC=8B?= Date: Fri, 14 Jun 2024 17:11:39 +0800 Subject: [PATCH 12/26] repair nms_rotated bug --- mmcv/ops/nms.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mmcv/ops/nms.py b/mmcv/ops/nms.py index 5115a95f62..d0c761ce39 100644 --- a/mmcv/ops/nms.py +++ b/mmcv/ops/nms.py @@ -452,7 +452,7 @@ def nms_rotated(dets: Tensor, flip_mat[-1] = -1 dets_cw = dets * flip_mat else: - dets_cw = dets + dets_cw = dets.clone() multi_label = labels is not None if labels is None: input_labels = scores.new_empty(0, dtype=torch.int) @@ -462,6 +462,8 @@ def nms_rotated(dets: Tensor, order = scores.new_empty(0, dtype=torch.long) if dets.device.type == 'npu': coefficient = 57.29578 # 180 / PI + dets_cw = dets_cw.float() + scores = scores.float() for i in range(dets.size()[0]): dets_cw[i][4] *= coefficient # radians to angle keep_inds = ext_module.nms_rotated(dets_cw, scores, order, dets_cw, From 38da28043d5d16a3c83415454d17ffccdb6bbea2 Mon Sep 17 00:00:00 2001 From: momo609 <963372609@qq.com> Date: Mon, 17 Jun 2024 09:22:51 +0800 Subject: [PATCH 13/26] fix three_interplote bug. --- .../pytorch/npu/three_interpolate_npu.cpp | 29 ++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp index f908755478..42d346f7d2 100644 --- a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp @@ -12,17 +12,21 @@ void three_interpolate_forward_npu(int b, int c, int m, int n, TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf), "three_interpolate_forward ascend only support fp32 and fp16."); - auto point_c_trans = points.transpose(1, 2); - + auto point_c_trans = points.transpose(1, 2).to(at::kFloat); + auto weight_cast = weight.to(at::kFloat); + auto out_cast = out.to(at::kFloat); OpCommand cmd; cmd.Name("ThreeInterpolate") .Input(point_c_trans) .Input(idx) - .Input(weight) - .Output(out) + .Input(weight_cast) + .Output(out_cast) .Run(); - auto output = out.view({b, n, c}).transpose(1, 2); + if (originDtype == at::kHalf) { + out_cast = out_cast.to(at::kHalf); + } + auto output = out_cast.view({b, n, c}).transpose(1, 2); auto res = output.contiguous(); out.copy_(res); } @@ -34,12 +38,17 @@ void three_interpolate_backward_npu(int b, int c, int n, int m, TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf), "three_interpolate_backward ascend only support fp32 and fp16."); - auto grad_x = at::unsqueeze(grad_out, 3); - auto grad_y = at::unsqueeze(grad_points, 3); - - EXEC_NPU_CMD(aclnnThreeInterpolateBackward, grad_x, idx, weight, m, grad_y); + auto grad_x = at::unsqueeze(grad_out, 3).to(at::kFloat); + auto grad_y = at::unsqueeze(grad_points, 3).to(at::kFloat); + auto weight_cast = weight.to(at::kFloat); + EXEC_NPU_CMD(aclnnThreeInterpolateBackward, grad_x, idx, weight_cast, m, + grad_y); - auto output = at::squeeze(grad_y, 3); + auto grad_y_cast = grad_y; + if (originDtype == at::kHalf) { + grad_y_cast = grad_y.to(at::kHalf); + } + auto output = at::squeeze(grad_y_cast, 3); auto res = output.contiguous(); grad_points.copy_(res); } From 95af19397f5774de6cdd12011fd93a28333761a8 Mon Sep 17 00:00:00 2001 From: lizekai Date: Fri, 14 Jun 2024 15:54:46 +0800 Subject: [PATCH 14/26] npu knn/tnn bugfix --- mmcv/ops/csrc/pytorch/npu/knn_npu.cpp | 21 ++++++++++++++++++++ mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp | 20 +++++++++++++++++++ mmcv/ops/knn.py | 23 ++++++++++++++++++++-- mmcv/ops/three_nn.py | 15 ++++++++++++++ 4 files changed, 77 insertions(+), 2 deletions(-) create mode 100644 mmcv/ops/csrc/pytorch/npu/knn_npu.cpp create mode 100644 mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp diff --git a/mmcv/ops/csrc/pytorch/npu/knn_npu.cpp b/mmcv/ops/csrc/pytorch/npu/knn_npu.cpp new file mode 100644 index 0000000000..c4a1bcbd25 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/knn_npu.cpp @@ -0,0 +1,21 @@ +#include "pytorch_npu_helper.hpp" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "torch_npu/csrc/framework/utils/OpAdapter.h" + +using namespace NPU_NAME_SPACE; +using namespace std; + +void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz, + const Tensor new_xyz, Tensor idx, Tensor dist2) { + // transpose known from [B, N, 3] to [B, 3, N] + at::Tensor source = xyz.transpose(2, 1).contiguous(); + at::Tensor target = new_xyz.contiguous(); + + bool is_from_knn = true; + EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, dist2); +} + +void knn_forward_impl(int b, int n, int m, int nsample, const Tensor xyz, + const Tensor new_xyz, Tensor idx, Tensor dist2); + +REGISTER_NPU_IMPL(knn_forward_impl, knn_forward_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp b/mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp new file mode 100644 index 0000000000..6740a731bc --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp @@ -0,0 +1,20 @@ +#include "pytorch_npu_helper.hpp" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "torch_npu/csrc/framework/utils/OpAdapter.h" + +using namespace NPU_NAME_SPACE; +using namespace std; + +void three_nn_forward_npu(int b, int n, int m, const Tensor unknown, + const Tensor known, Tensor dist2, Tensor idx) { + at::Tensor source = known.contiguous(); + at::Tensor target = unknown.contiguous(); + + bool is_from_knn = false; + EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, dist2); +} + +void three_nn_forward_impl(int b, int n, int m, const Tensor unknown, + const Tensor known, Tensor dist2, Tensor idx); + +REGISTER_NPU_IMPL(three_nn_forward_impl, three_nn_forward_npu); diff --git a/mmcv/ops/knn.py b/mmcv/ops/knn.py index 48ce92f925..1e2a68d1d2 100644 --- a/mmcv/ops/knn.py +++ b/mmcv/ops/knn.py @@ -55,12 +55,31 @@ def forward(ctx, center_xyz_device = center_xyz.get_device() assert center_xyz_device == xyz.get_device(), \ 'center_xyz and xyz should be put on the same device' - if torch.cuda.current_device() != center_xyz_device: - torch.cuda.set_device(center_xyz_device) + if xyz.device.type != 'npu': + if torch.cuda.current_device() != center_xyz_device: + torch.cuda.set_device(center_xyz_device) B, npoint, _ = center_xyz.shape N = xyz.shape[1] + if xyz.device.type == 'npu': + dist = center_xyz.new_zeros((B, npoint, N)).float() + ext_module.knn_forward( + xyz, + center_xyz, + torch.Tensor([]).npu(), + dist, + b=B, + n=N, + m=npoint, + nsample=k) + dist2, idx = torch.topk(dist, k, dim=2, largest=False, sorted=True) + zeros_idx = torch.zeros( + xyz.shape[0], center_xyz.shape[1], k, dtype=torch.int32).npu() + idx.where(dist2 >= 1e10, zeros_idx) + idx = idx.transpose(2, 1).contiguous() # [B, k, npoint] + return idx.int() + idx = center_xyz.new_zeros((B, npoint, k)).int() dist2 = center_xyz.new_zeros((B, npoint, k)).float() diff --git a/mmcv/ops/three_nn.py b/mmcv/ops/three_nn.py index d41b9789cf..52d504609a 100644 --- a/mmcv/ops/three_nn.py +++ b/mmcv/ops/three_nn.py @@ -34,6 +34,21 @@ def forward(ctx: Any, target: torch.Tensor, B, N, _ = target.size() m = source.size(1) + if source.device.type == 'npu': + # strict to fp32 + source = source.transpose(2, 1).contiguous() + dtype_ = source.dtype + if dtype_ == torch.float16: + target = target.float() + source = source.float() + dist = target.new_empty(B, N, m) + ext_module.three_nn_forward( + target, source, dist, torch.Tensor([]).npu(), b=B, n=N, m=m) + dist2, idx = torch.topk(dist, 3, dim=2, largest=False, sorted=True) + dist2 = torch.sqrt(dist2) + if dtype_ == torch.float16: + dist2 = dist2.half() + return dist2, idx.int() dist2 = target.new_empty(B, N, 3) idx = target.new_empty(B, N, 3, dtype=torch.int32) From 859a2cc4babcbe393e66bdd67620749239483ff0 Mon Sep 17 00:00:00 2001 From: momo609 <963372609@qq.com> Date: Mon, 17 Jun 2024 16:24:35 +0800 Subject: [PATCH 15/26] fix pointsinbox bug --- mmcv/ops/points_in_boxes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mmcv/ops/points_in_boxes.py b/mmcv/ops/points_in_boxes.py index 94c8dad33e..23c35da4eb 100644 --- a/mmcv/ops/points_in_boxes.py +++ b/mmcv/ops/points_in_boxes.py @@ -50,6 +50,8 @@ def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor: if points.device.type != 'npu': if torch.cuda.current_device() != points_device: torch.cuda.set_device(points_device) + elif points.device.type == 'npu': + boxes[:, :, 2] += boxes[:, :, 5] / 2.0 ext_module.points_in_boxes_part_forward(boxes.contiguous(), points.contiguous(), From 679a54a95e22cc12a1d3aa1b211de4b003db92c4 Mon Sep 17 00:00:00 2001 From: 15267151901 Date: Mon, 17 Jun 2024 21:08:15 +0800 Subject: [PATCH 16/26] fix deformConv and modulatedDeformConv input kernel_size --- mmcv/ops/deform_conv.py | 2 +- mmcv/ops/modulated_deform_conv.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mmcv/ops/deform_conv.py b/mmcv/ops/deform_conv.py index 73472dc9b1..6db4ddd2f6 100644 --- a/mmcv/ops/deform_conv.py +++ b/mmcv/ops/deform_conv.py @@ -54,7 +54,7 @@ def _npu_backward(ctx, grad_output): grad_input, grad_weight, grad_offset_all, grad_bias = \ torch.npu_deformable_conv2dbk( input_tensor, grad_output, offset_out, weight, offset_all, - kernel_size=[weight.shape[3], weight.shape[2]], + kernel_size=[weight.shape[2], weight.shape[3]], stride=[1, 1, ctx.stride[0], ctx.stride[1]], padding=[ctx.padding[0], ctx.padding[0], ctx.padding[1], ctx.padding[1]], diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py index 4c735e2a09..f66822771d 100644 --- a/mmcv/ops/modulated_deform_conv.py +++ b/mmcv/ops/modulated_deform_conv.py @@ -53,7 +53,7 @@ def _npu_forward(ctx, input_tensor, offset, mask, weight, bias): conv2d_bias = bias if len(bias) > 0 else None sort_index_fp, sort_index_bp = \ ModulatedDeformConv2dFunction._calculate_sort_index( - kernel_w, kernel_h, ctx.deform_groups) + kernel_h, kernel_w, ctx.deform_groups) select_offset = offset.index_select(1, sort_index_fp) offset_all = torch.cat([select_offset, mask], dim=1) output, offset_out = torch.npu_deformable_conv2d( @@ -61,7 +61,7 @@ def _npu_forward(ctx, input_tensor, offset, mask, weight, bias): weight, offset_all, conv2d_bias, - kernel_size=[kernel_w, kernel_h], + kernel_size=[kernel_h, kernel_w], stride=[1, 1, ctx.stride[0], ctx.stride[1]], padding=[ ctx.padding[0], ctx.padding[0], ctx.padding[1], ctx.padding[1] @@ -83,7 +83,7 @@ def _npu_backward(ctx, grad_output): grad_input, grad_weight, grad_offset_all, grad_bias = \ torch.npu_deformable_conv2dbk( input_tensor, grad_output, offset_out, weight, offset_all, - kernel_size=[weight.shape[3], weight.shape[2]], + kernel_size=[weight.shape[2], weight.shape[3]], stride=[1, 1, ctx.stride[0], ctx.stride[1]], padding=[ctx.padding[0], ctx.padding[0], ctx.padding[1], ctx.padding[1]], From 85d0ce4576f93705a15af999ff843f80bf0ea1e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=83=A1=E5=AE=8F=E7=AC=8B?= Date: Tue, 18 Jun 2024 09:35:25 +0800 Subject: [PATCH 17/26] repair nms_rotated bug --- mmcv/ops/nms.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/mmcv/ops/nms.py b/mmcv/ops/nms.py index d0c761ce39..1c3150dea2 100644 --- a/mmcv/ops/nms.py +++ b/mmcv/ops/nms.py @@ -452,23 +452,37 @@ def nms_rotated(dets: Tensor, flip_mat[-1] = -1 dets_cw = dets * flip_mat else: - dets_cw = dets.clone() + dets_cw = dets multi_label = labels is not None if labels is None: input_labels = scores.new_empty(0, dtype=torch.int) else: input_labels = labels - if dets.device.type in ('npu', 'mlu'): + + if dets.device.type == 'mlu': + order = scores.new_empty(0, dtype=torch.long) + keep_inds = ext_module.nms_rotated(dets_cw, scores, order, dets_cw, + input_labels, iou_threshold, + multi_label) + dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)), + dim=1) + return dets, keep_inds + + if dets.device.type == 'npu': order = scores.new_empty(0, dtype=torch.long) - if dets.device.type == 'npu': - coefficient = 57.29578 # 180 / PI + coefficient = 57.29578 # 180 / PI + if dets.dtype == torch.float16: dets_cw = dets_cw.float() - scores = scores.float() - for i in range(dets.size()[0]): - dets_cw[i][4] *= coefficient # radians to angle + else: + dets_cw = dets_cw.clone() + for i in range(dets.size()[0]): + dets_cw[i][4] *= coefficient # radians to angle + scores = scores.float() keep_inds = ext_module.nms_rotated(dets_cw, scores, order, dets_cw, input_labels, iou_threshold, multi_label) + if dets.dtype == torch.float16: + scores = scores.half() dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)), dim=1) return dets, keep_inds From 0852bb22838e9dacd1a9aad2c2cf3e54dde286a7 Mon Sep 17 00:00:00 2001 From: ZYF-Annarine Date: Wed, 19 Jun 2024 14:48:37 +0800 Subject: [PATCH 18/26] modify chamfer --- .../csrc/pytorch/npu/chamfer_distance_npu.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp index 4f5c32dbec..170a5fa72a 100644 --- a/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp @@ -6,17 +6,17 @@ using namespace std; void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1, Tensor dist2, Tensor idx1, Tensor idx2) { bool is_half = XYZ1.scalar_type() == at::kHalf; - at::Tensor xyz1 = at::ones_like(XYZ1); - at::Tensor xyz2 = at::ones_like(XYZ2); - at::Tensor distf1 = at::ones_like(dist1); - at::Tensor distf2 = at::ones_like(dist2); + at::Tensor xyz1 = XYZ1; + at::Tensor xyz2 = XYZ2; + at::Tensor distf1 = dist1; + at::Tensor distf2 = dist2; xyz1 = XYZ1.transpose(1, 2).transpose(0, 1); xyz2 = XYZ2.transpose(1, 2).transpose(0, 1); if (is_half) { xyz1 = xyz1.to(at::kFloat); xyz2 = xyz2.to(at::kFloat); - distf1 = dist1.to(at::kFloat); - distf2 = dist2.to(at::kFloat); + distf1 = distf1.to(at::kFloat); + distf2 = distf2.to(at::kFloat); } OpCommand cmd; cmd.Name("ChamferDistance") @@ -31,8 +31,8 @@ void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1, distf1 = distf1.to(at::kHalf); distf2 = distf2.to(at::kHalf); } - dist1.copy_(distf1); - dist2.copy_(distf2); + dist1 = distf1; + dist2 = distf2; } void chamfer_distance_backward_npu(Tensor xyz1, Tensor xyz2, Tensor idx1, From 5a75cd1b49a6621bff45d6e97bcca3005cbb09bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=8C=AF=E8=B1=AA?= Date: Fri, 21 Jun 2024 17:59:28 +0800 Subject: [PATCH 19/26] Bugfix of NPU adapter of nms3d --- mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp | 17 +++++++++-------- mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp | 18 +++++++++++------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp b/mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp index 5d812fe047..6d2588a01d 100644 --- a/mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp @@ -3,23 +3,24 @@ using namespace NPU_NAME_SPACE; void iou3d_nms3d_normal_forward_npu(const Tensor boxes, Tensor &keep, - Tensor &keep_num, - float nms_overlap_thresh) { + Tensor &num_out, float nms_overlap_thresh) { int32_t box_num = boxes.size(0); int32_t data_align = 16; int32_t mask_num = ((box_num - 1) / data_align + 1) * data_align; + const double iou_threshold = nms_overlap_thresh; at::Tensor mask = at::empty({box_num, mask_num}, boxes.options().dtype(at::kShort)); - EXEC_NPU_CMD(aclnnNms3dNormal, boxes, nms_overlap_thresh, mask); + EXEC_NPU_CMD(aclnnNms3dNormal, boxes, iou_threshold, mask); - keep = at::zeros({box_num}, mask.options()); - keep_num = at::zeros(1, mask.options()); - EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, keep_num); + Tensor keep_t = at::zeros({box_num}, mask.options()); + Tensor num_out_t = at::zeros(1, mask.options()); + EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep_t, num_out_t); + num_out.fill_(num_out_t.item().toLong()); + keep.copy_(keep_t); } void iou3d_nms3d_normal_forward_impl(const Tensor boxes, Tensor &keep, - Tensor &keep_num, - float nms_overlap_thresh); + Tensor &num_out, float nms_overlap_thresh); REGISTER_NPU_IMPL(iou3d_nms3d_normal_forward_impl, iou3d_nms3d_normal_forward_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp b/mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp index 13fe6db860..a143ed07b5 100644 --- a/mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp @@ -5,22 +5,26 @@ using namespace std; constexpr int32_t BOX_DIM = 7; -void iou3d_nms3d_forward_npu(const Tensor boxes, Tensor &keep, Tensor &keep_num, +void iou3d_nms3d_forward_npu(const Tensor boxes, Tensor &keep, Tensor &num_out, float nms_overlap_thresh) { TORCH_CHECK((boxes.sizes()[1] == BOX_DIM), "Input boxes shape should be (N, 7)"); int32_t box_num = boxes.size(0); int32_t data_align = 16; int32_t mask_num = ((box_num - 1) / data_align + 1) * data_align; + const double iou_threshold = nms_overlap_thresh; at::Tensor mask = at::empty({box_num, mask_num}, boxes.options().dtype(at::kShort)); - EXEC_NPU_CMD(aclnnNms3d, boxes, nms_overlap_thresh, mask); - keep = at::zeros({box_num}, mask.options()); - keep_num = at::zeros(1, mask.options()); - EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, keep_num); + EXEC_NPU_CMD(aclnnNms3d, boxes, iou_threshold, mask); + + Tensor keep_t = at::zeros({box_num}, mask.options()); + Tensor num_out_t = at::zeros(1, mask.options()); + EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep_t, num_out_t); + num_out.fill_(num_out_t.item().toLong()); + keep.copy_(keep_t); } -void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor &keep, - Tensor &keep_num, float nms_overlap_thresh); +void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor &keep, Tensor &num_out, + float nms_overlap_thresh); REGISTER_NPU_IMPL(iou3d_nms3d_forward_impl, iou3d_nms3d_forward_npu); From 323fbbb20e62927d2259a1ed644374b17d4c6f79 Mon Sep 17 00:00:00 2001 From: wujiadi Date: Tue, 25 Jun 2024 21:21:41 +0800 Subject: [PATCH 20/26] fix the bug of DeformableRoiPoolGrad --- mmcv/ops/csrc/pytorch/npu/deform_roi_pool.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmcv/ops/csrc/pytorch/npu/deform_roi_pool.cpp b/mmcv/ops/csrc/pytorch/npu/deform_roi_pool.cpp index 074e52d4f4..42de978e88 100644 --- a/mmcv/ops/csrc/pytorch/npu/deform_roi_pool.cpp +++ b/mmcv/ops/csrc/pytorch/npu/deform_roi_pool.cpp @@ -53,7 +53,7 @@ void deform_roi_pool_backward_npu(Tensor grad_output, Tensor input, Tensor rois, .Output(grad_offset) .Attr("output_size", output_size) .Attr("spatial_scale", spatial_scale) - .Attr("sample_ratio", sampling_ratio_) + .Attr("sampling_ratio", sampling_ratio_) .Attr("gamma", gamma) .Run(); } From 02d23c0735c53315a6ff795a8df19223d49bfc15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=8C=AF=E8=B1=AA?= Date: Fri, 19 Jul 2024 16:10:04 +0800 Subject: [PATCH 21/26] Interfaces change. --- mmcv/ops/scatter_points.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/mmcv/ops/scatter_points.py b/mmcv/ops/scatter_points.py index 5d881bfe63..c930326cee 100644 --- a/mmcv/ops/scatter_points.py +++ b/mmcv/ops/scatter_points.py @@ -36,6 +36,26 @@ def forward(ctx: Any, reduced from input features that share the same voxel coordinates. The second is voxel coordinates with shape [M, ndim]. """ + ctx.device = feats.device.type + if ctx.device == 'npu': + import ads_c + voxel_idx = ads_c.point_to_voxel(coors, [], []) + unique_res = ads_c.unique_voxel(voxel_idx) + num_voxels, uniqued_voxel_idx, prefix_sum, \ + argsort_coor, _ = unique_res + voxel_coors = ads_c.voxel_to_point(uniqued_voxel_idx, [], []) + voxel_feats, \ + compare_mask = ads_c.npu_dynamic_scatter(feats, coors, + prefix_sum, + argsort_coor, + num_voxels, + reduce_type) + ctx.reduce_type = reduce_type + ctx.feats_shape = feats.shape + ctx.save_for_backward(prefix_sum, argsort_coor, compare_mask) + ctx.mark_non_differentiable(voxel_coors) + return voxel_feats, voxel_coors + results = ext_module.dynamic_point_to_voxel_forward( feats, coors, reduce_type) (voxel_feats, voxel_coors, point2voxel_map, @@ -50,6 +70,19 @@ def forward(ctx: Any, def backward(ctx: Any, grad_voxel_feats: torch.Tensor, grad_voxel_coors: Optional[torch.Tensor] = None) -> tuple: + if ctx.device == 'npu': + import ads_c + prefix_sum, argsort_coor, compare_mask = ctx.saved_tensors + grad_point_feats = torch.zeros( + ctx.feats_shape, + dtype=grad_voxel_feats.dtype, + device=grad_voxel_feats.device) + ads_c.npu_dynamic_scatter_grad(grad_point_feats, + grad_voxel_feats.contiguous(), + prefix_sum, argsort_coor, + compare_mask, ctx.reduce_type) + return grad_point_feats, None, None + (feats, voxel_feats, point2voxel_map, voxel_points_count) = ctx.saved_tensors grad_feats = torch.zeros_like(feats) From cd324df57557dc94a2941ecd062f6e9b499eeceb Mon Sep 17 00:00:00 2001 From: ZYF-Annarine Date: Mon, 22 Jul 2024 15:59:41 +0800 Subject: [PATCH 22/26] chamfer push_back --- .../csrc/pytorch/npu/chamfer_distance_npu.cpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp index 170a5fa72a..9345da6dec 100644 --- a/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp @@ -6,17 +6,17 @@ using namespace std; void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1, Tensor dist2, Tensor idx1, Tensor idx2) { bool is_half = XYZ1.scalar_type() == at::kHalf; - at::Tensor xyz1 = XYZ1; - at::Tensor xyz2 = XYZ2; - at::Tensor distf1 = dist1; - at::Tensor distf2 = dist2; + at::Tensor xyz1 = at::ones_like(XYZ1); + at::Tensor xyz2 = at::ones_like(XYZ2); + at::Tensor distf1 = at::ones_like(dist1); + at::Tensor distf2 = at::ones_like(dist2); xyz1 = XYZ1.transpose(1, 2).transpose(0, 1); xyz2 = XYZ2.transpose(1, 2).transpose(0, 1); if (is_half) { xyz1 = xyz1.to(at::kFloat); xyz2 = xyz2.to(at::kFloat); - distf1 = distf1.to(at::kFloat); - distf2 = distf2.to(at::kFloat); + distf1 = dist1.to(at::kFloat); + distf2 = dist2.to(at::kFloat); } OpCommand cmd; cmd.Name("ChamferDistance") @@ -31,10 +31,11 @@ void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1, distf1 = distf1.to(at::kHalf); distf2 = distf2.to(at::kHalf); } - dist1 = distf1; - dist2 = distf2; + dist1.copy_(distf1); + dist2.copy_(distf2); } + void chamfer_distance_backward_npu(Tensor xyz1, Tensor xyz2, Tensor idx1, Tensor idx2, Tensor grad_dist1, Tensor grad_dist2, Tensor grad_xyz1, From 58cd6f6eec0787e7c505e085ba656c34b3635b77 Mon Sep 17 00:00:00 2001 From: Pr0Wh1teGivee Date: Mon, 5 Aug 2024 11:41:47 +0800 Subject: [PATCH 23/26] fix msda Update chamfer_distance_npu.cpp --- mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp | 1 - mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp | 12 +++--------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp index 9345da6dec..4f5c32dbec 100644 --- a/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp @@ -35,7 +35,6 @@ void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1, dist2.copy_(distf2); } - void chamfer_distance_backward_npu(Tensor xyz1, Tensor xyz2, Tensor idx1, Tensor idx2, Tensor grad_dist1, Tensor grad_dist2, Tensor grad_xyz1, diff --git a/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp b/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp index 7e943ca12f..1ad3ce3f91 100644 --- a/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp @@ -57,15 +57,9 @@ Tensor ms_deform_attn_forward_npu(const Tensor &value, value.size(0), sampling_locations.size(1), value.size(2) * value.size(3)}; at::Tensor output = at::zeros(output_size, value_fp32.options()); - OpCommand cmd; - cmd.Name("MultiScaleDeformableAttnFunction") - .Input(value_fp32) - .Input(value_spatial_shapes_int32) - .Input(value_level_start_index_int32) - .Input(sampling_locations_fp32) - .Input(attention_weights_fp32) - .Output(output) - .Run(); + EXEC_NPU_CMD(aclnnMultiScaleDeformableAttnFunction, value_fp32, + value_spatial_shapes_int32, value_level_start_index_int32, + sampling_locations_fp32, attention_weights_fp32, output); at::Tensor real_output = output; if (value.scalar_type() != at::kFloat) { From 73fbb2a208cf8e7347c9a3f22300d98b5caac4ad Mon Sep 17 00:00:00 2001 From: Pr0Wh1teGivee Date: Wed, 7 Aug 2024 17:01:39 +0800 Subject: [PATCH 24/26] fix softmax_focal_loss_grad --- mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp index 3f3bc5a047..ef7df560c9 100644 --- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -194,6 +194,7 @@ void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, weight_y = at::sum(weight_y, 1, true); weight_y = at::broadcast_to(weight_y, input.sizes()); } + grad_input_y = grad_input_y.fill_(0); OpCommand cmd; string reduction = "none"; cmd.Name("SoftmaxFocalLossGrad") From f2990014ad1059c069783b44e6765166a4545ffb Mon Sep 17 00:00:00 2001 From: hust17yixuan <303660421@qq.com> Date: Mon, 9 Dec 2024 15:43:08 +0800 Subject: [PATCH 25/26] fix bug --- mmcv/ops/deform_conv.py | 3 ++- mmcv/ops/modulated_deform_conv.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mmcv/ops/deform_conv.py b/mmcv/ops/deform_conv.py index 6db4ddd2f6..78f32eb6dc 100644 --- a/mmcv/ops/deform_conv.py +++ b/mmcv/ops/deform_conv.py @@ -49,10 +49,11 @@ def symbolic(g, @staticmethod def _npu_backward(ctx, grad_output): + import torch_npu input_tensor, weight, offset_out, offset_all, sort_index_for_npu_bp = \ ctx.saved_tensors grad_input, grad_weight, grad_offset_all, grad_bias = \ - torch.npu_deformable_conv2dbk( + torch_npu.npu_deformable_conv2dbk( input_tensor, grad_output, offset_out, weight, offset_all, kernel_size=[weight.shape[2], weight.shape[3]], stride=[1, 1, ctx.stride[0], ctx.stride[1]], diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py index f66822771d..83c9544e7f 100644 --- a/mmcv/ops/modulated_deform_conv.py +++ b/mmcv/ops/modulated_deform_conv.py @@ -80,8 +80,9 @@ def _npu_forward(ctx, input_tensor, offset, mask, weight, bias): def _npu_backward(ctx, grad_output): input_tensor, weight, offset_out, offset_all, sort_index_bp = \ ctx.saved_tensors + import torch_npu grad_input, grad_weight, grad_offset_all, grad_bias = \ - torch.npu_deformable_conv2dbk( + torch_npu.npu_deformable_conv2dbk( input_tensor, grad_output, offset_out, weight, offset_all, kernel_size=[weight.shape[2], weight.shape[3]], stride=[1, 1, ctx.stride[0], ctx.stride[1]], From e9c17c11e218aeea4fb5e61a8d8db982da06de00 Mon Sep 17 00:00:00 2001 From: hust17yixuan <303660421@qq.com> Date: Mon, 9 Dec 2024 15:49:49 +0800 Subject: [PATCH 26/26] fix bug forward --- mmcv/ops/modulated_deform_conv.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py index 83c9544e7f..7796044e8c 100644 --- a/mmcv/ops/modulated_deform_conv.py +++ b/mmcv/ops/modulated_deform_conv.py @@ -56,7 +56,8 @@ def _npu_forward(ctx, input_tensor, offset, mask, weight, bias): kernel_h, kernel_w, ctx.deform_groups) select_offset = offset.index_select(1, sort_index_fp) offset_all = torch.cat([select_offset, mask], dim=1) - output, offset_out = torch.npu_deformable_conv2d( + import torch_npu + output, offset_out = torch_npu.npu_deformable_conv2d( input_tensor, weight, offset_all,