summaryrefslogtreecommitdiff
path: root/ext/pybind11/tests/test_numpy_vectorize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ext/pybind11/tests/test_numpy_vectorize.cpp')
-rw-r--r--ext/pybind11/tests/test_numpy_vectorize.cpp17
1 files changed, 17 insertions, 0 deletions
diff --git a/ext/pybind11/tests/test_numpy_vectorize.cpp b/ext/pybind11/tests/test_numpy_vectorize.cpp
index 6d94db2a1..8e951c6e1 100644
--- a/ext/pybind11/tests/test_numpy_vectorize.cpp
+++ b/ext/pybind11/tests/test_numpy_vectorize.cpp
@@ -38,4 +38,21 @@ test_initializer numpy_vectorize([](py::module &m) {
m.def("selective_func", [](py::array_t<int, py::array::c_style>) { return "Int branch taken."; });
m.def("selective_func", [](py::array_t<float, py::array::c_style>) { return "Float branch taken."; });
m.def("selective_func", [](py::array_t<std::complex<float>, py::array::c_style>) { return "Complex float branch taken."; });
+
+
+ // Internal optimization test for whether the input is trivially broadcastable:
+ py::enum_<py::detail::broadcast_trivial>(m, "trivial")
+ .value("f_trivial", py::detail::broadcast_trivial::f_trivial)
+ .value("c_trivial", py::detail::broadcast_trivial::c_trivial)
+ .value("non_trivial", py::detail::broadcast_trivial::non_trivial);
+ m.def("vectorized_is_trivial", [](
+ py::array_t<int, py::array::forcecast> arg1,
+ py::array_t<float, py::array::forcecast> arg2,
+ py::array_t<double, py::array::forcecast> arg3
+ ) {
+ size_t ndim;
+ std::vector<size_t> shape;
+ std::array<py::buffer_info, 3> buffers {{ arg1.request(), arg2.request(), arg3.request() }};
+ return py::detail::broadcast(buffers, ndim, shape);
+ });
});