diff options
Diffstat (limited to 'ext/pybind11/tests/test_numpy_vectorize.cpp')
-rw-r--r-- | ext/pybind11/tests/test_numpy_vectorize.cpp | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/ext/pybind11/tests/test_numpy_vectorize.cpp b/ext/pybind11/tests/test_numpy_vectorize.cpp index 6d94db2a1..8e951c6e1 100644 --- a/ext/pybind11/tests/test_numpy_vectorize.cpp +++ b/ext/pybind11/tests/test_numpy_vectorize.cpp @@ -38,4 +38,21 @@ test_initializer numpy_vectorize([](py::module &m) { m.def("selective_func", [](py::array_t<int, py::array::c_style>) { return "Int branch taken."; }); m.def("selective_func", [](py::array_t<float, py::array::c_style>) { return "Float branch taken."; }); m.def("selective_func", [](py::array_t<std::complex<float>, py::array::c_style>) { return "Complex float branch taken."; }); + + + // Internal optimization test for whether the input is trivially broadcastable: + py::enum_<py::detail::broadcast_trivial>(m, "trivial") + .value("f_trivial", py::detail::broadcast_trivial::f_trivial) + .value("c_trivial", py::detail::broadcast_trivial::c_trivial) + .value("non_trivial", py::detail::broadcast_trivial::non_trivial); + m.def("vectorized_is_trivial", []( + py::array_t<int, py::array::forcecast> arg1, + py::array_t<float, py::array::forcecast> arg2, + py::array_t<double, py::array::forcecast> arg3 + ) { + size_t ndim; + std::vector<size_t> shape; + std::array<py::buffer_info, 3> buffers {{ arg1.request(), arg2.request(), arg3.request() }}; + return py::detail::broadcast(buffers, ndim, shape); + }); }); |