diff --git a/cpp/open3d/t/geometry/TriangleMesh.cpp b/cpp/open3d/t/geometry/TriangleMesh.cpp index d66c1ff91c2..85d76dc962f 100644 --- a/cpp/open3d/t/geometry/TriangleMesh.cpp +++ b/cpp/open3d/t/geometry/TriangleMesh.cpp @@ -1042,6 +1042,17 @@ static void CopyAttributesByMasks(TriangleMesh &dst, } TriangleMesh TriangleMesh::SelectFacesByMask(const core::Tensor &mask) const { + if (!HasVertexPositions()) { + utility::LogWarning( + "[SelectFacesByMask] mesh has no vertex positions."); + return {}; + } + if (!HasTriangleIndices()) { + utility::LogWarning( + "[SelectFacesByMask] mesh has no triangle indices."); + return {}; + } + core::AssertTensorShape(mask, {GetTriangleIndices().GetLength()}); core::AssertTensorDtype(mask, core::Bool); GetTriangleAttr().AssertSizeSynchronized(); @@ -1050,55 +1061,32 @@ TriangleMesh TriangleMesh::SelectFacesByMask(const core::Tensor &mask) const { // select triangles core::Tensor tris = GetTriangleIndices().IndexGet({mask}); core::Tensor tris_cpu = tris.To(core::Device()).Contiguous(); - const int64_t num_tris = tris_cpu.GetLength(); // create mask for vertices that are part of the selected faces const int64_t num_verts = GetVertexPositions().GetLength(); - core::Tensor vertex_mask = core::Tensor::Zeros({num_verts}, core::Int32); - std::vector prefix_sum(num_verts + 1, 0); - { - int32_t *vertex_mask_ptr = vertex_mask.GetDataPtr(); - if (tris_cpu.GetDtype() == core::Int32) { - int32_t *vert_idx_ptr = tris_cpu.GetDataPtr(); - for (int64_t i = 0; i < tris_cpu.GetLength() * 3; ++i) { - vertex_mask_ptr[vert_idx_ptr[i]] = 1; - } - } else { - int64_t *vert_idx_ptr = tris_cpu.GetDataPtr(); - for (int64_t i = 0; i < tris_cpu.GetLength() * 3; ++i) { - vertex_mask_ptr[vert_idx_ptr[i]] = 1; - } - } - utility::InclusivePrefixSum( - vertex_mask_ptr, vertex_mask_ptr + num_verts, &prefix_sum[1]); - } - - // update triangle indices - if (tris_cpu.GetDtype() == core::Int32) { - int32_t *vert_idx_ptr = tris_cpu.GetDataPtr(); - for (int64_t i = 0; i < num_tris * 3; ++i) { - int64_t new_idx = prefix_sum[vert_idx_ptr[i]]; - vert_idx_ptr[i] = int32_t(new_idx); - } - } else { - int64_t *vert_idx_ptr = tris_cpu.GetDataPtr(); + // empty tensor to further construct the vertex mask + core::Tensor vertex_mask; + + DISPATCH_INT_DTYPE_PREFIX_TO_TEMPLATE(tris_cpu.GetDtype(), tris, [&]() { + vertex_mask = core::Tensor::Zeros( + {num_verts}, core::Dtype::FromType()); + const int64_t num_tris = tris_cpu.GetLength(); + scalar_tris_t *vertex_mask_ptr = + vertex_mask.GetDataPtr(); + scalar_tris_t *vert_idx_ptr = tris_cpu.GetDataPtr(); + // mask for the vertices, which are used in the triangles for (int64_t i = 0; i < num_tris * 3; ++i) { - int64_t new_idx = prefix_sum[vert_idx_ptr[i]]; - vert_idx_ptr[i] = new_idx; + vertex_mask_ptr[vert_idx_ptr[i]] = 1; } - } + UpdateTriangleIndicesByVertexMask(tris_cpu, vertex_mask); + }); tris = tris_cpu.To(GetDevice()); vertex_mask = vertex_mask.To(GetDevice(), core::Bool); core::Tensor verts = GetVertexPositions().IndexGet({vertex_mask}); TriangleMesh result(verts, tris); - // copy attributes - for (auto item : GetVertexAttr()) { - if (!result.HasVertexAttr(item.first)) { - result.SetVertexAttr(item.first, - item.second.IndexGet({vertex_mask})); - } + CopyAttributesByMasks(result, *this, vertex_mask, mask); return result; } diff --git a/cpp/open3d/t/geometry/TriangleMesh.h b/cpp/open3d/t/geometry/TriangleMesh.h index 2cb5ffdd383..8f7373afd9d 100644 --- a/cpp/open3d/t/geometry/TriangleMesh.h +++ b/cpp/open3d/t/geometry/TriangleMesh.h @@ -927,7 +927,8 @@ class TriangleMesh : public Geometry, public DrawableGeometry { /// Returns a new mesh with the faces selected by a boolean mask. /// \param mask A boolean mask with the shape (N) with N as the number of /// faces in the mesh. - /// \return A new mesh with the selected faces. + /// \return A new mesh with the selected faces. If the original mesh is + /// empty, return an empty mesh. TriangleMesh SelectFacesByMask(const core::Tensor &mask) const; /// Returns a new mesh with the vertices selected by a vector of indices. diff --git a/cpp/pybind/t/geometry/trianglemesh.cpp b/cpp/pybind/t/geometry/trianglemesh.cpp index 20f2c79beaa..06cacf404a5 100644 --- a/cpp/pybind/t/geometry/trianglemesh.cpp +++ b/cpp/pybind/t/geometry/trianglemesh.cpp @@ -901,7 +901,7 @@ the partition id for each face. number of faces in the mesh. Returns: - A new mesh with the selected faces. + A new mesh with the selected faces. If the original mesh is empty, return an empty mesh. Example: diff --git a/cpp/tests/t/geometry/TriangleMesh.cpp b/cpp/tests/t/geometry/TriangleMesh.cpp index 50ec92a1d5f..67f37a7dd86 100644 --- a/cpp/tests/t/geometry/TriangleMesh.cpp +++ b/cpp/tests/t/geometry/TriangleMesh.cpp @@ -941,6 +941,113 @@ TEST_P(TriangleMeshPermuteDevices, CreateMobius) { triangle_indices_custom)); } +TEST_P(TriangleMeshPermuteDevices, SelectFacesByMask) { + // check that an exception is thrown if the mesh is empty + t::geometry::TriangleMesh mesh_empty; + core::Tensor mask_empty = + core::Tensor::Zeros({12}, core::Bool, mesh_empty.GetDevice()); + core::Tensor mask_full = + core::Tensor::Ones({12}, core::Bool, mesh_empty.GetDevice()); + + // check completely empty mesh + EXPECT_TRUE(mesh_empty.SelectFacesByMask(mask_empty).IsEmpty()); + EXPECT_TRUE(mesh_empty.SelectFacesByMask(mask_full).IsEmpty()); + + // check mesh w/o triangles + core::Tensor cpu_vertices = + core::Tensor::Ones({2, 3}, core::Float32, mesh_empty.GetDevice()); + mesh_empty.SetVertexPositions(cpu_vertices); + EXPECT_TRUE(mesh_empty.SelectFacesByMask(mask_empty).IsEmpty()); + EXPECT_TRUE(mesh_empty.SelectFacesByMask(mask_full).IsEmpty()); + + // create box with normals, colors and labels defined. + t::geometry::TriangleMesh box = t::geometry::TriangleMesh::CreateBox(); + core::Tensor vertex_colors = core::Tensor::Init({{0.0, 0.0, 0.0}, + {1.0, 1.0, 1.0}, + {2.0, 2.0, 2.0}, + {3.0, 3.0, 3.0}, + {4.0, 4.0, 4.0}, + {5.0, 5.0, 5.0}, + {6.0, 6.0, 6.0}, + {7.0, 7.0, 7.0}}); + ; + core::Tensor vertex_labels = core::Tensor::Init({{0.0, 0.0, 0.0}, + {1.0, 1.0, 1.0}, + {2.0, 2.0, 2.0}, + {3.0, 3.0, 3.0}, + {4.0, 4.0, 4.0}, + {5.0, 5.0, 5.0}, + {6.0, 6.0, 6.0}, + {7.0, 7.0, 7.0}}) * + 10; + ; + core::Tensor triangle_labels = + core::Tensor::Init({{0.0, 0.0, 0.0}, + {1.0, 1.0, 1.0}, + {2.0, 2.0, 2.0}, + {3.0, 3.0, 3.0}, + {4.0, 4.0, 4.0}, + {5.0, 5.0, 5.0}, + {6.0, 6.0, 6.0}, + {7.0, 7.0, 7.0}, + {8.0, 8.0, 8.0}, + {9.0, 9.0, 9.0}, + {10.0, 10.0, 10.0}, + {11.0, 11.0, 11.0}}) * + 100; + box.SetVertexColors(vertex_colors); + box.SetVertexAttr("labels", vertex_labels); + box.ComputeTriangleNormals(); + box.SetTriangleAttr("labels", triangle_labels); + + // empty index list + EXPECT_TRUE(box.SelectFacesByMask(mask_empty).IsEmpty()); + + // set the expected value + core::Tensor expected_verts = core::Tensor::Init({{0.0, 0.0, 1.0}, + {1.0, 0.0, 1.0}, + {0.0, 1.0, 1.0}, + {1.0, 1.0, 1.0}}); + core::Tensor expected_vert_colors = + core::Tensor::Init({{2.0, 2.0, 2.0}, + {3.0, 3.0, 3.0}, + {6.0, 6.0, 6.0}, + {7.0, 7.0, 7.0}}); + core::Tensor expected_vert_labels = + core::Tensor::Init({{20.0, 20.0, 20.0}, + {30.0, 30.0, 30.0}, + {60.0, 60.0, 60.0}, + {70.0, 70.0, 70.0}}); + core::Tensor expected_tris = + core::Tensor::Init({{0, 1, 3}, {0, 3, 2}}); + core::Tensor tris_mask = + core::Tensor::Init({0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0}); + core::Tensor expected_tri_normals = + box.GetTriangleNormals().IndexGet({tris_mask}); + core::Tensor expected_tri_labels = core::Tensor::Init( + {{800.0, 800.0, 800.0}, {900.0, 900.0, 900.0}}); + + // check basic case + t::geometry::TriangleMesh selected = box.SelectFacesByMask(tris_mask); + + EXPECT_TRUE(selected.GetVertexPositions().AllClose(expected_verts)); + EXPECT_TRUE(selected.GetVertexColors().AllClose(expected_vert_colors)); + EXPECT_TRUE( + selected.GetVertexAttr("labels").AllClose(expected_vert_labels)); + EXPECT_TRUE(selected.GetTriangleIndices().AllClose(expected_tris)); + EXPECT_TRUE(selected.GetTriangleNormals().AllClose(expected_tri_normals)); + EXPECT_TRUE( + selected.GetTriangleAttr("labels").AllClose(expected_tri_labels)); + + // Check that initial mesh is unchanged. + t::geometry::TriangleMesh box_untouched = + t::geometry::TriangleMesh::CreateBox(); + EXPECT_TRUE(box.GetVertexPositions().AllClose( + box_untouched.GetVertexPositions())); + EXPECT_TRUE(box.GetTriangleIndices().AllClose( + box_untouched.GetTriangleIndices())); +} + TEST_P(TriangleMeshPermuteDevices, SelectByIndex) { // check that an exception is thrown if the mesh is empty t::geometry::TriangleMesh mesh_empty; diff --git a/python/test/t/geometry/test_trianglemesh.py b/python/test/t/geometry/test_trianglemesh.py index e8e54abd3d7..2a108adff56 100644 --- a/python/test/t/geometry/test_trianglemesh.py +++ b/python/test/t/geometry/test_trianglemesh.py @@ -419,6 +419,96 @@ def test_pickle(device): mesh.triangle.indices.cpu().numpy()) +@pytest.mark.parametrize("device", list_devices()) +def test_select_faces_by_mask_32(device): + sphere_custom = o3d.t.geometry.TriangleMesh.create_sphere( + 1, 3, o3c.float64, o3c.int32, device) + + expected_verts = o3c.Tensor( + [[0.0, 0.0, 1.0], [0.866025, 0, 0.5], [0.433013, 0.75, 0.5], + [-0.866025, 0.0, 0.5], [-0.433013, -0.75, 0.5], [0.433013, -0.75, 0.5] + ], o3c.float64, device) + + expected_tris = o3c.Tensor([[0, 1, 2], [0, 3, 4], [0, 4, 5], [0, 5, 1]], + o3c.int32, device) + + # check indices shape mismatch + mask_2d = o3c.Tensor([[False, False], [False, False], [False, False]], + o3c.bool, device) + with pytest.raises(RuntimeError): + selected = sphere_custom.select_faces_by_mask(mask_2d) + + # check indices type mismatch + mask_float = o3c.Tensor([ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + ], o3c.float32, device) + with pytest.raises(RuntimeError): + selected = sphere_custom.select_faces_by_mask(mask_float) + + # check the basic case + mask = o3c.Tensor([ + True, False, False, False, False, False, True, False, True, False, True, + False, False, False, False, False, False, False, False, False, False, + False, False, False + ], o3c.bool, device) + selected = sphere_custom.select_faces_by_mask(mask) + assert selected.vertex.positions.allclose(expected_verts) + assert selected.triangle.indices.allclose(expected_tris) + + # check that the original mesh is unmodified + untouched_sphere = o3d.t.geometry.TriangleMesh.create_sphere( + 1, 3, o3c.float64, o3c.int32, device) + assert sphere_custom.vertex.positions.allclose( + untouched_sphere.vertex.positions) + assert sphere_custom.triangle.indices.allclose( + untouched_sphere.triangle.indices) + + +@pytest.mark.parametrize("device", list_devices()) +def test_select_faces_by_mask_64(device): + sphere_custom = o3d.t.geometry.TriangleMesh.create_sphere( + 1, 3, o3c.float64, o3c.int64, device) + + # check indices shape mismatch + mask_2d = o3c.Tensor([[False, False], [False, False], [False, False]], + o3c.bool, device) + with pytest.raises(RuntimeError): + selected = sphere_custom.select_faces_by_mask(mask_2d) + + # check indices type mismatch + mask_float = o3c.Tensor([ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + ], o3c.float32, device) + with pytest.raises(RuntimeError): + selected = sphere_custom.select_faces_by_mask(mask_float) + + expected_verts = o3c.Tensor( + [[0.0, 0.0, 1.0], [0.866025, 0, 0.5], [0.433013, 0.75, 0.5], + [-0.866025, 0.0, 0.5], [-0.433013, -0.75, 0.5], [0.433013, -0.75, 0.5] + ], o3c.float64, device) + + expected_tris = o3c.Tensor([[0, 1, 2], [0, 3, 4], [0, 4, 5], [0, 5, 1]], + o3c.int64, device) + # check the basic case + mask = o3c.Tensor([ + True, False, False, False, False, False, True, False, True, False, True, + False, False, False, False, False, False, False, False, False, False, + False, False, False + ], o3c.bool, device) + + selected = sphere_custom.select_faces_by_mask(mask) + assert selected.vertex.positions.allclose(expected_verts) + assert selected.triangle.indices.allclose(expected_tris) + + # check that the original mesh is unmodified + untouched_sphere = o3d.t.geometry.TriangleMesh.create_sphere( + 1, 3, o3c.float64, o3c.int64, device) + assert sphere_custom.vertex.positions.allclose( + untouched_sphere.vertex.positions) + assert sphere_custom.triangle.indices.allclose( + untouched_sphere.triangle.indices) + + @pytest.mark.parametrize("device", list_devices()) def test_select_by_index_32(device): sphere_custom = o3d.t.geometry.TriangleMesh.create_sphere(