summaryrefslogtreecommitdiff
path: root/ext/pybind11/tests/test_numpy_vectorize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ext/pybind11/tests/test_numpy_vectorize.cpp')
-rw-r--r--ext/pybind11/tests/test_numpy_vectorize.cpp49
1 files changed, 40 insertions, 9 deletions
diff --git a/ext/pybind11/tests/test_numpy_vectorize.cpp b/ext/pybind11/tests/test_numpy_vectorize.cpp
index 8e951c6e1..a875a74b9 100644
--- a/ext/pybind11/tests/test_numpy_vectorize.cpp
+++ b/ext/pybind11/tests/test_numpy_vectorize.cpp
@@ -16,11 +16,11 @@ double my_func(int x, float y, double z) {
return (float) x*y*z;
}
-std::complex<double> my_func3(std::complex<double> c) {
- return c * std::complex<double>(2.f);
-}
+TEST_SUBMODULE(numpy_vectorize, m) {
+ try { py::module::import("numpy"); }
+ catch (...) { return; }
-test_initializer numpy_vectorize([](py::module &m) {
+ // test_vectorize, test_docs, test_array_collapse
// Vectorize all arguments of a function (though non-vector arguments are also allowed)
m.def("vectorized_func", py::vectorize(my_func));
@@ -32,14 +32,45 @@ test_initializer numpy_vectorize([](py::module &m) {
);
// Vectorize a complex-valued function
- m.def("vectorized_func3", py::vectorize(my_func3));
+ m.def("vectorized_func3", py::vectorize(
+ [](std::complex<double> c) { return c * std::complex<double>(2.f); }
+ ));
- /// Numpy function which only accepts specific data types
+ // test_type_selection
+ // Numpy function which only accepts specific data types
m.def("selective_func", [](py::array_t<int, py::array::c_style>) { return "Int branch taken."; });
m.def("selective_func", [](py::array_t<float, py::array::c_style>) { return "Float branch taken."; });
m.def("selective_func", [](py::array_t<std::complex<float>, py::array::c_style>) { return "Complex float branch taken."; });
+ // test_passthrough_arguments
+ // Passthrough test: references and non-pod types should be automatically passed through (in the
+ // function definition below, only `b`, `d`, and `g` are vectorized):
+ struct NonPODClass {
+ NonPODClass(int v) : value{v} {}
+ int value;
+ };
+ py::class_<NonPODClass>(m, "NonPODClass").def(py::init<int>());
+ m.def("vec_passthrough", py::vectorize(
+ [](double *a, double b, py::array_t<double> c, const int &d, int &e, NonPODClass f, const double g) {
+ return *a + b + c.at(0) + d + e + f.value + g;
+ }
+ ));
+
+ // test_method_vectorization
+ struct VectorizeTestClass {
+ VectorizeTestClass(int v) : value{v} {};
+ float method(int x, float y) { return y + (float) (x + value); }
+ int value = 0;
+ };
+ py::class_<VectorizeTestClass> vtc(m, "VectorizeTestClass");
+ vtc .def(py::init<int>())
+ .def_readwrite("value", &VectorizeTestClass::value);
+
+ // Automatic vectorizing of methods
+ vtc.def("method", py::vectorize(&VectorizeTestClass::method));
+
+ // test_trivial_broadcasting
// Internal optimization test for whether the input is trivially broadcastable:
py::enum_<py::detail::broadcast_trivial>(m, "trivial")
.value("f_trivial", py::detail::broadcast_trivial::f_trivial)
@@ -50,9 +81,9 @@ test_initializer numpy_vectorize([](py::module &m) {
py::array_t<float, py::array::forcecast> arg2,
py::array_t<double, py::array::forcecast> arg3
) {
- size_t ndim;
- std::vector<size_t> shape;
+ ssize_t ndim;
+ std::vector<ssize_t> shape;
std::array<py::buffer_info, 3> buffers {{ arg1.request(), arg2.request(), arg3.request() }};
return py::detail::broadcast(buffers, ndim, shape);
});
-});
+}