From 7691ee9cf542a02fbdc777b9bab8db0b4d8c29f8 Mon Sep 17 00:00:00 2001 From: Marcin Wojdyr Date: Mon, 16 Sep 2024 20:46:10 +0200 Subject: [PATCH] simplify MaskedGrid::iterator and simplify and change python MaskedGrid.mask_array, now it's oriented in the same way as the grid --- docs/grid.rst | 2 +- include/gemmi/asumask.hpp | 32 +++++++++++--------------------- python/grid.cpp | 26 +++++++++++--------------- 3 files changed, 23 insertions(+), 37 deletions(-) diff --git a/docs/grid.rst b/docs/grid.rst index 1521247f..e3d0487d 100644 --- a/docs/grid.rst +++ b/docs/grid.rst @@ -452,7 +452,7 @@ The primary use for MaskedGrid is working with asymmetric unit (asu) only: >>> asu.mask_array # doctest: +ELLIPSIS array([[[0, 0, 0, ..., 1, 1, 1], ... - [1, 0, 0, ..., 1, 1, 1]]], dtype=int8) + [1, 1, 1, ..., 1, 1, 1]]], dtype=int8) >>> sum(point.value for point in asu) 7.125 >>> for point in asu: diff --git a/include/gemmi/asumask.hpp b/include/gemmi/asumask.hpp index c9e8ab9f..4d527267 100644 --- a/include/gemmi/asumask.hpp +++ b/include/gemmi/asumask.hpp @@ -220,32 +220,22 @@ template struct MaskedGrid { Grid* grid; struct iterator { - MaskedGrid& parent; - size_t index; - int u = 0, v = 0, w = 0; - iterator(MaskedGrid& parent_, size_t index_) - : parent(parent_), index(index_) {} + typename GridBase::iterator grid_iterator; + const std::vector& mask_ref; + iterator(typename GridBase::iterator it, const std::vector& mask) + : grid_iterator(it), mask_ref(mask) {} iterator& operator++() { do { - ++index; - if (++u == parent.grid->nu) { - u = 0; - if (++v == parent.grid->nv) { - v = 0; - ++w; - } - } - } while (index != parent.mask.size() && parent.mask[index] != 0); + ++grid_iterator; + } while (grid_iterator.index != mask_ref.size() && mask_ref[grid_iterator.index] != 0); return *this; } - typename GridBase::Point operator*() { - return {u, v, w, &parent.grid->data[index]}; - } - bool operator==(const iterator &o) const { return index == o.index; } - bool operator!=(const iterator &o) const { return index != o.index; } + typename GridBase::Point operator*() { return *grid_iterator; } + bool operator==(const iterator &o) const { return grid_iterator == o.grid_iterator; } + bool operator!=(const iterator &o) const { return grid_iterator != o.grid_iterator; } }; - iterator begin() { return {*this, 0}; } - iterator end() { return {*this, mask.size()}; } + iterator begin() { return {grid->begin(), mask}; } + iterator end() { return {grid->end(), mask}; } }; template diff --git a/python/grid.cpp b/python/grid.cpp index e3c85800..95129228 100644 --- a/python/grid.cpp +++ b/python/grid.cpp @@ -28,6 +28,15 @@ bool operator>(const std::complex& a, const std::complex& b) { using namespace gemmi; +template +auto grid_to_array(GridMeta& g, std::vector& data) { + // should we take AxisOrder into account here? + return nb::ndarray(data.data(), + {(size_t)g.nu, (size_t)g.nv, (size_t)g.nw}, + nb::handle(), + {1, g.nu, g.nu * g.nv}); +} + template nb::class_, GridMeta> add_grid_base(nb::module_& m, const char* name) { using GrBase = GridBase; @@ -46,14 +55,7 @@ nb::class_, GridMeta> add_grid_base(nb::module_& m, const char* name self.w, ") -> ", +*self.value, '>'); }); - auto to_array = [](GrBase& g) { - // should we take AxisOrder into account here? - return nb::ndarray(g.data.data(), - {(size_t)g.nu, (size_t)g.nv, (size_t)g.nw}, - nb::handle(), - {1, g.nu, g.nu * g.nv}); - }; - + auto to_array = [](GrBase& gr) { return grid_to_array(gr, gr.data); }; grid_base .def_prop_ro("array", to_array, nb::rv_policy::reference_internal) .def("__array__", to_array, nb::rv_policy::reference_internal) @@ -147,13 +149,7 @@ nb::class_, GridBase> add_grid_common(nb::module_& m, const std::stri masked_grid .def_ro("grid", &Masked::grid, nb::rv_policy::reference) .def_prop_ro("mask_array", [](Masked& self) { - const Gr& gr = *self.grid; - // cf. to_array() above - return nb::ndarray( - self.mask.data(), - {(size_t)gr.nu, (size_t)gr.nv, (size_t)gr.nw}, - nb::handle(), - {int64_t(gr.nv * gr.nw), int64_t(gr.nw), 1}); + return grid_to_array(*self.grid, self.mask); }, nb::rv_policy::reference_internal) .def("__iter__", [](Masked& self) { return usual_iterator(self, self);