Skip to content

Commit

Permalink
simplify MaskedGrid::iterator
Browse files Browse the repository at this point in the history
and simplify and change python MaskedGrid.mask_array,
now it's oriented in the same way as the grid
  • Loading branch information
wojdyr committed Sep 17, 2024
1 parent bd81ea0 commit 7691ee9
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 37 deletions.
2 changes: 1 addition & 1 deletion docs/grid.rst
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ The primary use for MaskedGrid is working with asymmetric unit (asu) only:
>>> asu.mask_array # doctest: +ELLIPSIS
array([[[0, 0, 0, ..., 1, 1, 1],
...
[1, 0, 0, ..., 1, 1, 1]]], dtype=int8)
[1, 1, 1, ..., 1, 1, 1]]], dtype=int8)
>>> sum(point.value for point in asu)
7.125
>>> for point in asu:
Expand Down
32 changes: 11 additions & 21 deletions include/gemmi/asumask.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,32 +220,22 @@ template<typename T, typename V=std::int8_t> struct MaskedGrid {
Grid<T>* grid;

struct iterator {
MaskedGrid& parent;
size_t index;
int u = 0, v = 0, w = 0;
iterator(MaskedGrid& parent_, size_t index_)
: parent(parent_), index(index_) {}
typename GridBase<T>::iterator grid_iterator;
const std::vector<V>& mask_ref;
iterator(typename GridBase<T>::iterator it, const std::vector<V>& mask)
: grid_iterator(it), mask_ref(mask) {}
iterator& operator++() {
do {
++index;
if (++u == parent.grid->nu) {
u = 0;
if (++v == parent.grid->nv) {
v = 0;
++w;
}
}
} while (index != parent.mask.size() && parent.mask[index] != 0);
++grid_iterator;
} while (grid_iterator.index != mask_ref.size() && mask_ref[grid_iterator.index] != 0);
return *this;
}
typename GridBase<T>::Point operator*() {
return {u, v, w, &parent.grid->data[index]};
}
bool operator==(const iterator &o) const { return index == o.index; }
bool operator!=(const iterator &o) const { return index != o.index; }
typename GridBase<T>::Point operator*() { return *grid_iterator; }
bool operator==(const iterator &o) const { return grid_iterator == o.grid_iterator; }
bool operator!=(const iterator &o) const { return grid_iterator != o.grid_iterator; }
};
iterator begin() { return {*this, 0}; }
iterator end() { return {*this, mask.size()}; }
iterator begin() { return {grid->begin(), mask}; }
iterator end() { return {grid->end(), mask}; }
};

template<typename V=std::int8_t>
Expand Down
26 changes: 11 additions & 15 deletions python/grid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ bool operator>(const std::complex<float>& a, const std::complex<float>& b) {

using namespace gemmi;

template<typename T>
auto grid_to_array(GridMeta& g, std::vector<T>& data) {
// should we take AxisOrder into account here?
return nb::ndarray<nb::numpy, T>(data.data(),
{(size_t)g.nu, (size_t)g.nv, (size_t)g.nw},
nb::handle(),
{1, g.nu, g.nu * g.nv});
}

template<typename T>
nb::class_<GridBase<T>, GridMeta> add_grid_base(nb::module_& m, const char* name) {
using GrBase = GridBase<T>;
Expand All @@ -46,14 +55,7 @@ nb::class_<GridBase<T>, GridMeta> add_grid_base(nb::module_& m, const char* name
self.w, ") -> ", +*self.value, '>');
});

auto to_array = [](GrBase& g) {
// should we take AxisOrder into account here?
return nb::ndarray<nb::numpy, T>(g.data.data(),
{(size_t)g.nu, (size_t)g.nv, (size_t)g.nw},
nb::handle(),
{1, g.nu, g.nu * g.nv});
};

auto to_array = [](GrBase& gr) { return grid_to_array(gr, gr.data); };
grid_base
.def_prop_ro("array", to_array, nb::rv_policy::reference_internal)
.def("__array__", to_array, nb::rv_policy::reference_internal)
Expand Down Expand Up @@ -147,13 +149,7 @@ nb::class_<Grid<T>, GridBase<T>> add_grid_common(nb::module_& m, const std::stri
masked_grid
.def_ro("grid", &Masked::grid, nb::rv_policy::reference)
.def_prop_ro("mask_array", [](Masked& self) {
const Gr& gr = *self.grid;
// cf. to_array() above
return nb::ndarray<nb::numpy, std::int8_t>(
self.mask.data(),
{(size_t)gr.nu, (size_t)gr.nv, (size_t)gr.nw},
nb::handle(),
{int64_t(gr.nv * gr.nw), int64_t(gr.nw), 1});
return grid_to_array(*self.grid, self.mask);
}, nb::rv_policy::reference_internal)
.def("__iter__", [](Masked& self) {
return usual_iterator(self, self);
Expand Down

0 comments on commit 7691ee9

Please sign in to comment.