Skip to content

Commit

Permalink
Draft of SmallMatrix Python bindings.
Browse files Browse the repository at this point in the history
  • Loading branch information
cemitch99 committed Dec 6, 2024
1 parent 706e751 commit 38ddba5
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ target_sources(pyImpactX
ReferenceParticle.cpp
transformation.cpp
WakeConvolution.cpp
SmallMatrix.cpp
)
72 changes: 72 additions & 0 deletions src/python/SmallMatrix.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/* Copyright 2021-2023 The ImpactX Community
*
* Authors: Ryan Sandberg, Axel Huebl
* License: BSD-3-Clause-LBNL
*/
#include "pyImpactX.H"
#include <AMReX_SmallMatrix.H>

namespace py = pybind11;

namespace pybind11 {
namespace detail {

template <typename T, int NRows, int NCols, amrex::Order ORDER, int StartIndex>
struct pybind11::detail::type_caster<amrex::SmallMatrix<T, NRows, NCols, ORDER, StartIndex>> {
public:
PYBIND11_TYPE_CASTER(amrex::SmallMatrix<T, NRows, NCols, ORDER, StartIndex>,
_("SmallMatrix[") + py::detail::make_caster<T>::name() + _("]"));

// Conversion from Python to C++
bool load(handle src, bool) {
// Ensure we have a numpy array
py::array_t<T> arr = py::cast<py::array_t<T>>(src);
py::buffer_info buf = arr.request();

// Check dimensions and shape
if (buf.ndim != 2) {
throw std::runtime_error("SmallMatrix requires a 2D array.");
}
if (buf.shape[0] != NRows || buf.shape[1] != NCols) {
throw std::runtime_error("SmallMatrix array shape must match NRows x NCols.");
}

// Create a SmallMatrix and copy data
amrex::SmallMatrix<T, NRows, NCols, ORDER, StartIndex> mat;
T* ptr = static_cast<T*>(buf.ptr);
for (int i = 0; i < NRows * NCols; ++i) {
mat.m_mat[i] = ptr[i];
}

value = mat;
return true;
}

// Conversion from C++ to Python
static handle cast(const amrex::SmallMatrix<T, NRows, NCols, ORDER, StartIndex>& src,
return_value_policy /* policy */, handle /* parent */) {
py::array_t<T> arr({NRows, NCols});
py::buffer_info buf = arr.request();
T* ptr = static_cast<T*>(buf.ptr);
for (int i = 0; i < NRows * NCols; ++i) {
ptr[i] = src.m_mat[i];
}
return arr.release();
}
};

} // namespace detail
} // namespace pybind11


PYBIND11_MODULE(example, m) {
// You can now just bind constructors and methods normally without defining conversion code:
py::class_<amrex::SmallMatrix<double, 6, 6>>(m, "SmallMatrix6x6")
.def(py::init<>()) // Default init
.def("as_array", [](const amrex::SmallMatrix<double, 6, 6>& mat) {
return mat; // Will use type_caster to return a numpy array
});

// Now Python functions expecting a SmallMatrix<double,6,6> can pass a numpy array directly:
// def some_func(mat: SmallMatrix6x6): ...
}
1 change: 1 addition & 0 deletions src/python/pyImpactX.H
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <pybind11/functional.h>
#include <pybind11/numpy.h>

#include <particles/elements/All.H>

Expand Down

0 comments on commit 38ddba5

Please sign in to comment.