Fix npu op bugs. #3127

Status: Open. Wants to merge 35 commits into base: 1.x.

Commits (35)
- 2a1967d add ms_deform_attn (momo609, Feb 4, 2024)
- 846f5b5 add points_in_box (momo609, Feb 5, 2024)
- 98393a3 fix bug. (momo609, Feb 29, 2024)
- 56cda46 add multi npu op (huaweiZJX, Nov 29, 2023)
- 14f6bdc add multi npu op. (huaweiZJX, Nov 29, 2023)
- bfffaeb add constraints of pointinpolygon (momo609, May 13, 2024)
- abf8ca7 fix roi_pool bug. (momo609, May 28, 2024)
- 737b5b4 fix gather_point bug. (momo609, May 30, 2024)
- d90969b chamfer_distance fp16->fp32 (Annarine, Jun 6, 2024)
- e4e8f50 chamfer_distance fp16->fp32 (Annarine, Jun 6, 2024)
- 4d7bf26 update constraints of points_in_polygons (huhongsun, May 17, 2024)
- c38593d repair nms_rotated bug (huhongsun, Jun 14, 2024)
- 38da280 fix three_interplote bug. (momo609, Jun 17, 2024)
- 95af193 npu knn/tnn bugfix (Jun 14, 2024)
- 128f29a Merge pull request #25 from huhongsun/rc41.x (momo609, Jun 17, 2024)
- 859a2cc fix pointsinbox bug (momo609, Jun 17, 2024)
- fca25b1 Merge pull request #23 from Binary2355/rc41.x (momo609, Jun 17, 2024)
- 679a54a fix deformConv and modulatedDeformConv input kernel_size (wuzheyi1028, Jun 17, 2024)
- c94c723 Merge pull request #30 from wuzheyi1028/wzy/rc41.x (momo609, Jun 18, 2024)
- 85d0ce4 repair nms_rotated bug (huhongsun, Jun 18, 2024)
- 0852bb2 modify chamfer (Annarine, Jun 19, 2024)
- 0356569 Merge pull request #32 from huhongsun/rc41.x (momo609, Jun 20, 2024)
- 5a75cd1 Bugfix of NPU adapter of nms3d (DaGaiBa, Jun 21, 2024)
- 3bc1aa1 Merge pull request #40 from DaGaiBa/rc41.x (momo609, Jun 22, 2024)
- 323fbbb fix the bug of DeformableRoiPoolGrad (wujiadi1, Jun 25, 2024)
- 4482059 Merge pull request #44 from wujiadi1/rc41.x (momo609, Jun 26, 2024)
- 02d23c0 Interfaces change. (DaGaiBa, Jul 19, 2024)
- 815c6a6 Merge pull request #52 from DaGaiBa/rc41.x (momo609, Jul 19, 2024)
- cd324df chamfer push_back (Annarine, Jul 22, 2024)
- 58cd6f6 fix msda (Pr0Wh1teGivee, Aug 5, 2024)
- 6dd2b80 Merge pull request #55 from Pr0Wh1teGivee/msda_fix_rc41 (momo609, Aug 5, 2024)
- 73fbb2a fix softmax_focal_loss_grad (Pr0Wh1teGivee, Aug 7, 2024)
- e39eb62 Merge pull request #56 from Pr0Wh1teGivee/softmax_fix_rc41 (momo609, Aug 7, 2024)
- f299001 fix bug (hust17yixuan, Dec 9, 2024)
- e9c17c1 fix bug forward (hust17yixuan, Dec 9, 2024)
4 changes: 2 additions & 2 deletions docs/zh_cn/understand_mmcv/ops.md
@@ -22,8 +22,8 @@ MMCV 提供了检测、分割等任务中常用的算子
 | Deformable RoIPool | | √ | √ | | √ |
 | DiffIoURotated | | √ | √ | | |
 | DynamicScatter | | √ | √ | | |
-| FurthestPointSample | | √ | | | |
-| FurthestPointSampleWithDist | | √ | | | |
+| FurthestPointSample | | √ | | | √ |
+| FurthestPointSampleWithDist | | √ | | | √ |
 | FusedBiasLeakyrelu | | √ | | | √ |
 | GatherPoints | | √ | | | √ |
 | GroupPoints | | √ | | | |
20 changes: 17 additions & 3 deletions mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp
@@ -1,24 +1,38 @@
 
 #include "pytorch_npu_helper.hpp"
 
 using namespace NPU_NAME_SPACE;
 using namespace std;
 
 void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1,
                                   Tensor dist2, Tensor idx1, Tensor idx2) {
+  bool is_half = XYZ1.scalar_type() == at::kHalf;
   at::Tensor xyz1 = at::ones_like(XYZ1);
   at::Tensor xyz2 = at::ones_like(XYZ2);
+  at::Tensor distf1 = at::ones_like(dist1);
+  at::Tensor distf2 = at::ones_like(dist2);
   xyz1 = XYZ1.transpose(1, 2).transpose(0, 1);
   xyz2 = XYZ2.transpose(1, 2).transpose(0, 1);
+  if (is_half) {
+    xyz1 = xyz1.to(at::kFloat);
+    xyz2 = xyz2.to(at::kFloat);
+    distf1 = dist1.to(at::kFloat);
+    distf2 = dist2.to(at::kFloat);
+  }
   OpCommand cmd;
   cmd.Name("ChamferDistance")
       .Input(xyz1)
       .Input(xyz2)
-      .Output(dist1)
-      .Output(dist2)
+      .Output(distf1)
+      .Output(distf2)
       .Output(idx1)
       .Output(idx2)
       .Run();
+  if (is_half) {
+    distf1 = distf1.to(at::kHalf);
+    distf2 = distf2.to(at::kHalf);
+  }
+  dist1.copy_(distf1);
+  dist2.copy_(distf2);
 }
 
 void chamfer_distance_backward_npu(Tensor xyz1, Tensor xyz2, Tensor idx1,
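
The forward fix above follows a pattern repeated throughout this PR: cast half-precision tensors to fp32, run the NPU kernel on the fp32 copies, cast the result back to fp16, and copy_ it into the caller's output buffer. A minimal standalone sketch of that wrapper, where the run_op callback is a hypothetical stand-in for the OpCommand invocation, not a real API:

#include <functional>
#include <torch/torch.h>

// Sketch of the fp16 -> fp32 -> fp16 round trip used by the NPU adapters.
void run_fp32_with_half_io(
    const at::Tensor& in, at::Tensor& out,
    const std::function<void(const at::Tensor&, at::Tensor&)>& run_op) {
  const bool is_half = in.scalar_type() == at::kHalf;
  at::Tensor in_f = is_half ? in.to(at::kFloat) : in;
  at::Tensor out_f = is_half ? out.to(at::kFloat) : out;
  run_op(in_f, out_f);  // the kernel only ever sees fp32 tensors
  if (is_half) {
    out_f = out_f.to(at::kHalf);  // cast the fp32 result back down
  }
  out.copy_(out_f);  // write through to the caller's buffer
}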
2 changes: 1 addition & 1 deletion mmcv/ops/csrc/pytorch/npu/deform_roi_pool.cpp
@@ -53,7 +53,7 @@ void deform_roi_pool_backward_npu(Tensor grad_output, Tensor input, Tensor rois,
       .Output(grad_offset)
       .Attr("output_size", output_size)
       .Attr("spatial_scale", spatial_scale)
-      .Attr("sample_ratio", sampling_ratio_)
+      .Attr("sampling_ratio", sampling_ratio_)
       .Attr("gamma", gamma)
       .Run();
 }
106 changes: 86 additions & 20 deletions mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp
@@ -4,6 +4,21 @@ using namespace std;
 
 void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, float gamma, float alpha) {
+  at::Tensor input_y = input;
+  at::Tensor output_y = output;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+    output_y = output.to(at::kFloat);
+  }
+  int64_t weight_size = weight.size(0);
+  at::Tensor weight_y = at::ones_like(input_y);
+  if (weight_size > 0) {
+    weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
+  }
   int64_t n_class = input.size(1);
   at::Tensor target_y = at::ones_like(input);
   if (n_class == 1) {
@@ -12,24 +27,26 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
     target_y = at::add(target_y, 1.0);
   } else {
     target_y = at::one_hot(target, n_class);
+    weight_y = at::mul(weight_y, target_y);
+    weight_y = at::sum(weight_y, 1, true);
+    weight_y = at::broadcast_to(weight_y, input.sizes());
   }
   target_y = target_y.to(at::kInt);
-  int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
-  if (weight_size > 0) {
-    weight_y = at::broadcast_to(weight, input.sizes());
-  }
   OpCommand cmd;
   string reduction = "none";
   cmd.Name("SigmoidFocalLoss")
-      .Input(input)
+      .Input(input_y)
       .Input(target_y)
       .Input(weight_y)
-      .Output(output)
+      .Output(output_y)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
       .Attr("reduction", reduction)
       .Run();
+  if (is_half) {
+    output_y = output_y.to(at::kHalf);
+  }
+  output.copy_(output_y);
 }
 
 void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
@@ -38,34 +55,51 @@ void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
 void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
                                      Tensor grad_input, float gamma,
                                      float alpha) {
+  at::Tensor input_y = input;
+  at::Tensor grad_input_y = grad_input;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+    grad_input_y = grad_input.to(at::kFloat);
+  }
+  int64_t weight_size = weight.size(0);
+  at::Tensor weight_y = at::ones_like(input_y);
+  if (weight_size > 0) {
+    weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
+  }
   int64_t n_class = input.size(1);
   at::Tensor target_y = at::ones_like(input);
   if (n_class == 1) {
     target_y = at::reshape(target, input.sizes());
   } else {
     target_y = at::one_hot(target, n_class);
+    weight_y = at::mul(weight_y, target_y);
+    weight_y = at::sum(weight_y, 1, true);
+    weight_y = at::broadcast_to(weight_y, input.sizes());
     target_y = at::mul(target_y, -1.0);
     target_y = at::add(target_y, 1.0);
   }
   target_y = target_y.to(at::kInt);
   at::Tensor grad_up = at::ones_like(input);
-  int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
-  if (weight_size > 0) {
-    weight_y = at::broadcast_to(weight, input.sizes());
-  }
   OpCommand cmd;
   string reduction = "none";
   cmd.Name("SigmoidFocalLossGrad")
-      .Input(input)
+      .Input(input_y)
       .Input(target_y)
       .Input(grad_up)
       .Input(weight_y)
-      .Output(grad_input)
+      .Output(grad_input_y)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
       .Attr("reduction", reduction)
       .Run();
+  if (is_half) {
+    grad_input_y = grad_input_y.to(at::kHalf);
+  }
+  grad_input.copy_(grad_input_y);
 }
 
 void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
@@ -74,26 +108,40 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
 
 void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, float gamma, float alpha) {
+  at::Tensor input_y = input;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+  }
   int64_t n_class = input.size(1);
   at::Tensor target_y = at::one_hot(target, n_class);
   target_y = target_y.to(at::kInt);
   int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
+  at::Tensor weight_y = at::ones_like(input_y);
   if (weight_size > 0) {
     weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
+    weight_y = at::mul(weight_y, target_y);
+    weight_y = at::sum(weight_y, 1, true);
+    weight_y = at::broadcast_to(weight_y, input.sizes());
   }
-  at::Tensor op_output = at::ones_like(input);
+  at::Tensor op_output = at::ones_like(input_y);
   OpCommand cmd;
   string reduction = "none";
   cmd.Name("SoftmaxFocalLoss")
-      .Input(input)
+      .Input(input_y)
       .Input(target_y)
       .Input(weight_y)
       .Output(op_output)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
       .Attr("reduction", reduction)
       .Run();
+  if (is_half) {
+    op_output = op_output.to(at::kHalf);
+  }
   int64_t n_batch = input.size(0);
   c10::SmallVector<int64_t, 2> offsets = {0, 0};
   c10::SmallVector<int64_t, 2> sizes = {n_batch, 1};
@@ -124,27 +172,45 @@ void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
 void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
                                      Tensor buff, Tensor grad_input,
                                      float gamma, float alpha) {
+  at::Tensor input_y = input;
+  at::Tensor grad_input_y = grad_input;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+    grad_input_y = grad_input.to(at::kFloat);
+  }
   int64_t n_class = input.size(1);
   at::Tensor target_y = at::one_hot(target, n_class);
   target_y = target_y.to(at::kInt);
   at::Tensor grad_up = at::ones_like(input);
   int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
+  at::Tensor weight_y = at::ones_like(input_y);
   if (weight_size > 0) {
     weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
+    weight_y = at::mul(weight_y, target_y);
+    weight_y = at::sum(weight_y, 1, true);
+    weight_y = at::broadcast_to(weight_y, input.sizes());
   }
+  grad_input_y = grad_input_y.fill_(0);
   OpCommand cmd;
   string reduction = "none";
   cmd.Name("SoftmaxFocalLossGrad")
-      .Input(input)
+      .Input(input_y)
       .Input(target_y)
      .Input(grad_up)
       .Input(weight_y)
-      .Output(grad_input)
+      .Output(grad_input_y)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
       .Attr("reduction", reduction)
       .Run();
+  if (is_half) {
+    grad_input_y = grad_input_y.to(at::kHalf);
+  }
+  grad_input.copy_(grad_input_y);
 }
 
 void softmax_focal_loss_backward_impl(Tensor input, Tensor target,
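
A detail shared by all four focal-loss adapters above: the per-class weight vector of length C must become a per-element weight map shaped like input, and it must be gathered through the one-hot target before the kernel runs. A standalone sketch of that computation (names here are illustrative, not the adapter's API):

#include <torch/torch.h>

// Per-element weights from a per-class weight vector, as the adapters do it:
// broadcast weight over [N, C], mask with the one-hot target, reduce over the
// class axis to [N, 1], then broadcast back to [N, C].
at::Tensor per_element_weight(const at::Tensor& weight,   // [C]
                              const at::Tensor& target,   // [N] class indices
                              const at::Tensor& input) {  // [N, C]
  int64_t n_class = input.size(1);
  at::Tensor one_hot = at::one_hot(target, n_class).to(weight.scalar_type());
  at::Tensor w = at::broadcast_to(weight, input.sizes());
  w = at::sum(at::mul(w, one_hot), /*dim=*/1, /*keepdim=*/true);
  return at::broadcast_to(w, input.sizes());
}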
23 changes: 23 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/furthest_point_sample_npu.cpp
@@ -0,0 +1,23 @@
+#include "pytorch_npu_helper.hpp"
+
+using namespace NPU_NAME_SPACE;
+using namespace std;
+
+void furthest_point_sampling_forward_npu(Tensor points_tensor,
+                                         Tensor temp_tensor, Tensor idx_tensor,
+                                         int b, int n, int m) {
+  TORCH_CHECK(
+      (points_tensor.sizes()[1] >= m),
+      "the number of sampled points should not exceed the total number of points.");
+  at::Tensor points_xyz = points_tensor.transpose(1, 2).contiguous();
+  at::Tensor nearest_dist = temp_tensor.contiguous();
+  EXEC_NPU_CMD(aclnnFurthestPointSampling, points_xyz, nearest_dist, m,
+               idx_tensor);
+}
+
+void furthest_point_sampling_forward_impl(Tensor points_tensor,
+                                          Tensor temp_tensor, Tensor idx_tensor,
+                                          int b, int n, int m);
+
+REGISTER_NPU_IMPL(furthest_point_sampling_forward_impl,
+                  furthest_point_sampling_forward_npu);
@@ -0,0 +1,22 @@
+#include "pytorch_npu_helper.hpp"
+using namespace NPU_NAME_SPACE;
+using namespace std;
+
+void furthest_point_sampling_with_dist_npu(Tensor points_tensor,
+                                           Tensor temp_tensor,
+                                           Tensor idx_tensor, int b, int n,
+                                           int m) {
+  TORCH_CHECK(
+      (points_tensor.sizes()[1] >= m),
+      "the number of sampled points should not exceed the total number of points.");
+  EXEC_NPU_CMD(aclnnFurthestPointSamplingWithDist, points_tensor, temp_tensor,
+               m, idx_tensor);
+}
+
+void furthest_point_sampling_with_dist_forward_impl(Tensor points_tensor,
+                                                    Tensor temp_tensor,
+                                                    Tensor idx_tensor, int b,
+                                                    int n, int m);
+
+REGISTER_NPU_IMPL(furthest_point_sampling_with_dist_forward_impl,
+                  furthest_point_sampling_with_dist_npu);
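
Both new furthest-point-sampling adapters validate that the requested sample count does not exceed the point count, then dispatch to aclnn kernels. A hypothetical call sketch (shapes follow the mmcv convention; device placement is omitted, since on a real run these tensors live on the NPU):

#include <torch/torch.h>

// Assumed example: sample m = 128 of n = 1024 points per batch element.
void example_fps() {
  at::Tensor points = torch::rand({2, 1024, 3});         // [B, N, 3] coordinates
  at::Tensor temp = torch::full({2, 1024}, 1e10f);       // nearest-distance scratch buffer
  at::Tensor idx = torch::zeros({2, 128}, torch::kInt);  // [B, m] sampled indices (output)
  // m must not exceed N, or the adapter's TORCH_CHECK fires.
  furthest_point_sampling_forward_npu(points, temp, idx, /*b=*/2, /*n=*/1024, /*m=*/128);
}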
16 changes: 13 additions & 3 deletions mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp
@@ -24,6 +24,12 @@ void gather_points_forward_npu(int b, int c, int n, int npoints,
 void gather_points_backward_npu(int b, int c, int n, int npoints,
                                 const Tensor grad_out, const Tensor idx,
                                 Tensor grad_points) {
+  at::Tensor grad_out_cast = grad_out;
+  at::Tensor grad_points_cast = grad_points;
+  if (grad_out.scalar_type() == at::ScalarType::Half) {
+    grad_out_cast = grad_out.to(at::kFloat);
+    grad_points_cast = grad_points.to(at::kFloat);
+  }
   at::Tensor indices = idx;
   if (idx.scalar_type() != at::ScalarType::Int) {
     indices = idx.to(at::kInt);
@@ -37,11 +43,11 @@ void gather_points_backward_npu(int b, int c, int n, int npoints,
   for (uint64_t i = 0; i < shape.size(); i++) {
     pad_size.emplace_back(shape[i]);
   }
-  at::Tensor trans_grad_points = grad_points.transpose(1, 2).contiguous();
+  at::Tensor trans_grad_points = grad_points_cast.transpose(1, 2).contiguous();
   at::Tensor grad_points_view = trans_grad_points.view(
       {trans_grad_points.sizes()[0] * trans_grad_points.sizes()[1],
        trans_grad_points.sizes()[2]});
-  at::Tensor trans_grad_out = grad_out.transpose(1, 2).contiguous();
+  at::Tensor trans_grad_out = grad_out_cast.transpose(1, 2).contiguous();
   trans_grad_out = trans_grad_out.view(
       {trans_grad_out.sizes()[0] * trans_grad_out.sizes()[1],
        trans_grad_out.sizes()[2]});
@@ -63,7 +69,11 @@ void gather_points_backward_npu(int b, int c, int n, int npoints,
   at::Tensor grad_points_result =
       grad_points_view.view(trans_grad_points.sizes());
   grad_points_result = grad_points_result.transpose(1, 2);
-  grad_points.copy_(grad_points_result);
+  at::Tensor grad_points_result_cast = grad_points_result;
+  if (grad_out.scalar_type() == at::ScalarType::Half) {
+    grad_points_result_cast = grad_points_result.to(at::kHalf);
+  }
+  grad_points.copy_(grad_points_result_cast);
 }
 
 void gather_points_forward_impl(int b, int c, int n, int npoints,
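
For reference, the gradient this backward pass materializes is a scatter-add: each sampled column of grad_out accumulates into the source column of grad_points selected by idx. A plain-libtorch sketch of the same math (a reference implementation, not the NPU code path above):

#include <torch/torch.h>

// Reference semantics of gather_points backward: accumulate grad_out[b, :, j]
// into grad_points[b, :, idx[b][j]] for every batch b and sample j.
void gather_points_backward_ref(const at::Tensor& grad_out,  // [B, C, M]
                                const at::Tensor& idx,       // [B, M] int
                                at::Tensor& grad_points) {   // [B, C, N]
  int64_t B = grad_out.size(0);
  for (int64_t b = 0; b < B; ++b) {
    // index_add_ along the point dimension handles duplicate indices correctly.
    grad_points[b].index_add_(/*dim=*/1, idx[b].to(at::kLong), grad_out[b]);
  }
}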
21 changes: 21 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/knn_npu.cpp
@@ -0,0 +1,21 @@
+#include "pytorch_npu_helper.hpp"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+
+using namespace NPU_NAME_SPACE;
+using namespace std;
+
+void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz,
+                     const Tensor new_xyz, Tensor idx, Tensor dist2) {
+  // transpose known from [B, N, 3] to [B, 3, N]
+  at::Tensor source = xyz.transpose(2, 1).contiguous();
+  at::Tensor target = new_xyz.contiguous();
+
+  bool is_from_knn = true;
+  EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, dist2);
+}
+
+void knn_forward_impl(int b, int n, int m, int nsample, const Tensor xyz,
+                      const Tensor new_xyz, Tensor idx, Tensor dist2);
+
+REGISTER_NPU_IMPL(knn_forward_impl, knn_forward_npu);
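
As a reference for what the aclnnKnn dispatch computes, dist2 holds squared Euclidean distances from each query point in new_xyz to every source point in xyz. A plain-libtorch sketch of the same quantity (reference math only, not the NPU path):

#include <torch/torch.h>

at::Tensor knn_dist2_ref(const at::Tensor& xyz,        // [B, N, 3] source points
                         const at::Tensor& new_xyz) {  // [B, M, 3] query points
  at::Tensor d = at::cdist(new_xyz, xyz);  // [B, M, N] Euclidean distances
  // The k nearest indices would then come from
  // (d * d).topk(nsample, /*dim=*/2, /*largest=*/false).
  return d * d;  // squared distances, matching dist2
}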