diff --git a/python_bindings/src/halide/CMakeLists.txt b/python_bindings/src/halide/CMakeLists.txt index c0f48f569eb1..df3d8fe8bbfb 100644 --- a/python_bindings/src/halide/CMakeLists.txt +++ b/python_bindings/src/halide/CMakeLists.txt @@ -20,6 +20,7 @@ set(native_sources PyLoopLevel.cpp PyModule.cpp PyParam.cpp + PyParameter.cpp PyPipeline.cpp PyRDom.cpp PyStage.cpp diff --git a/python_bindings/src/halide/halide_/PyHalide.cpp b/python_bindings/src/halide/halide_/PyHalide.cpp index 5b0926f8c8ab..750ee6cc092f 100644 --- a/python_bindings/src/halide/halide_/PyHalide.cpp +++ b/python_bindings/src/halide/halide_/PyHalide.cpp @@ -18,6 +18,7 @@ #include "PyLambda.h" #include "PyModule.h" #include "PyParam.h" +#include "PyParameter.h" #include "PyPipeline.h" #include "PyRDom.h" #include "PyTarget.h" @@ -61,6 +62,7 @@ PYBIND11_MODULE(HALIDE_PYBIND_MODULE_NAME, m) { define_lambda(m); define_operators(m); define_param(m); + define_parameter(m); define_image_param(m); define_type(m); define_derivative(m); diff --git a/python_bindings/src/halide/halide_/PyParam.cpp b/python_bindings/src/halide/halide_/PyParam.cpp index 8cf0ae1b9945..ac5c13d45d35 100644 --- a/python_bindings/src/halide/halide_/PyParam.cpp +++ b/python_bindings/src/halide/halide_/PyParam.cpp @@ -38,29 +38,6 @@ void add_param_methods(py::class_> ¶m_class) { } // namespace void define_param(py::module &m) { - // This is a "just-enough" wrapper around Parameter to let us pass it back - // and forth between Py and C++. It deliberately exposes very few methods, - // and we should keep it that way. - auto parameter_class = - py::class_(m, "InternalParameter") - .def(py::init(), py::arg("p")) - .def("defined", &Parameter::defined) - .def("type", &Parameter::type) - .def("dimensions", &Parameter::dimensions) - .def("_to_argument", [](const Parameter &p) -> Argument { - return Argument(p.name(), - p.is_buffer() ? Argument::InputBuffer : Argument::InputScalar, - p.type(), - p.dimensions(), - p.get_argument_estimates()); - }) - .def("__repr__", [](const Parameter &p) -> std::string { - std::ostringstream o; - // Don't leak any info but the name into the repr string. - o << ""; - return o.str(); - }); - auto param_class = py::class_>(m, "Param") .def(py::init(), py::arg("type")) diff --git a/python_bindings/src/halide/halide_/PyParameter.cpp b/python_bindings/src/halide/halide_/PyParameter.cpp new file mode 100644 index 000000000000..50dac943b7b7 --- /dev/null +++ b/python_bindings/src/halide/halide_/PyParameter.cpp @@ -0,0 +1,107 @@ +#include "PyParameter.h" + +#include "PyType.h" + +namespace Halide { +namespace PythonBindings { + +namespace { + +template +void add_scalar_methods(py::class_ ¶meter_class) { + parameter_class + .def("scalar", &Parameter::scalar) + .def( + "set_scalar", [](Parameter ¶meter, TYPE value) -> void { + parameter.set_scalar(value); + }, + py::arg("value")); +} + +} // namespace + +void define_parameter(py::module &m) { + + // Disambiguate some ambigious methods + void (Parameter::*set_scalar_method)(const Type &t, halide_scalar_value_t val) = &Parameter::set_scalar; + + auto parameter_class = + py::class_(m, "Parameter") + .def(py::init<>()) + .def(py::init(), py::arg("p")) + .def(py::init()) + .def(py::init()) + .def(py::init &, int, const std::vector &, + MemoryType>()) + .def(py::init()) + .def("_to_argument", [](const Parameter &p) -> Argument { + return Argument(p.name(), + p.is_buffer() ? Argument::InputBuffer : Argument::InputScalar, + p.type(), + p.dimensions(), + p.get_argument_estimates()); + }) + .def("__repr__", [](const Parameter &p) -> std::string { + std::ostringstream o; + o << ""; + return o.str(); + }) + .def("type", &Parameter::type) + .def("dimensions", &Parameter::dimensions) + .def("name", &Parameter::name) + .def("is_buffer", &Parameter::is_buffer) + .def("scalar_expr", &Parameter::scalar_expr) + .def("set_scalar", set_scalar_method, py::arg("value_type"), py::arg("value")) + .def("buffer", &Parameter::buffer) + .def("set_buffer", &Parameter::set_buffer, py::arg("buffer")) + .def("same_as", &Parameter::same_as, py::arg("other")) + .def("defined", &Parameter::defined) + .def("set_min_constraint", &Parameter::set_min_constraint, py::arg("dim"), py::arg("expr")) + .def("set_extent_constraint", &Parameter::set_extent_constraint, py::arg("dim"), py::arg("expr")) + .def("set_stride_constraint", &Parameter::set_stride_constraint, py::arg("dim"), py::arg("expr")) + .def("set_min_constraint_estimate", &Parameter::set_min_constraint_estimate, py::arg("dim"), py::arg("expr")) + .def("set_extent_constraint_estimate", &Parameter::set_extent_constraint_estimate, py::arg("dim"), py::arg("expr")) + .def("set_host_alignment", &Parameter::set_host_alignment, py::arg("bytes")) + .def("min_constraint", &Parameter::min_constraint, py::arg("dim")) + .def("extent_constraint", &Parameter::extent_constraint, py::arg("dim")) + .def("stride_constraint", &Parameter::stride_constraint, py::arg("dim")) + .def("min_constraint_estimate", &Parameter::min_constraint_estimate, py::arg("dim")) + .def("extent_constraint_estimate", &Parameter::extent_constraint_estimate, py::arg("dim")) + .def("host_alignment", &Parameter::host_alignment) + .def("buffer_constraints", &Parameter::buffer_constraints) + .def("set_min_value", &Parameter::set_min_value, py::arg("expr")) + .def("min_value", &Parameter::min_value) + .def("set_max_value", &Parameter::set_max_value, py::arg("expr")) + .def("max_value", &Parameter::max_value) + .def("set_estimate", &Parameter::set_estimate, py::arg("expr")) + .def("estimate", &Parameter::estimate) + .def("set_default_value", &Parameter::set_default_value, py::arg("expr")) + .def("default_value", &Parameter::default_value) + .def("get_argument_estimates", &Parameter::get_argument_estimates) + .def("store_in", &Parameter::store_in, py::arg("memory_type")) + .def("memory_type", &Parameter::memory_type); + + add_scalar_methods(parameter_class); + add_scalar_methods(parameter_class); + add_scalar_methods(parameter_class); + add_scalar_methods(parameter_class); + add_scalar_methods(parameter_class); + add_scalar_methods(parameter_class); + add_scalar_methods(parameter_class); + add_scalar_methods(parameter_class); + add_scalar_methods(parameter_class); + add_scalar_methods(parameter_class); + add_scalar_methods(parameter_class); +} + +} // namespace PythonBindings +} // namespace Halide diff --git a/python_bindings/src/halide/halide_/PyParameter.h b/python_bindings/src/halide/halide_/PyParameter.h new file mode 100644 index 000000000000..49f84743a9d9 --- /dev/null +++ b/python_bindings/src/halide/halide_/PyParameter.h @@ -0,0 +1,14 @@ +#ifndef HALIDE_PYTHON_BINDINGS_PYPARAMETER_H +#define HALIDE_PYTHON_BINDINGS_PYPARAMETER_H + +#include "PyHalide.h" + +namespace Halide { +namespace PythonBindings { + +void define_parameter(py::module &m); + +} // namespace PythonBindings +} // namespace Halide + +#endif // HALIDE_PYTHON_BINDINGS_PYPARAMETER_H