diff options
Diffstat (limited to 'ext/pybind11/include/pybind11/numpy.h')
-rw-r--r-- | ext/pybind11/include/pybind11/numpy.h | 462 |
1 files changed, 343 insertions, 119 deletions
diff --git a/ext/pybind11/include/pybind11/numpy.h b/ext/pybind11/include/pybind11/numpy.h index e6f4efdf9..3227a12eb 100644 --- a/ext/pybind11/include/pybind11/numpy.h +++ b/ext/pybind11/include/pybind11/numpy.h @@ -35,9 +35,11 @@ static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t"); NAMESPACE_BEGIN(pybind11) + +class array; // Forward declaration + NAMESPACE_BEGIN(detail) -template <typename type, typename SFINAE = void> struct npy_format_descriptor { }; -template <typename type> struct is_pod_struct; +template <typename type, typename SFINAE = void> struct npy_format_descriptor; struct PyArrayDescr_Proxy { PyObject_HEAD @@ -108,11 +110,11 @@ inline numpy_internals& get_numpy_internals() { struct npy_api { enum constants { - NPY_C_CONTIGUOUS_ = 0x0001, - NPY_F_CONTIGUOUS_ = 0x0002, + NPY_ARRAY_C_CONTIGUOUS_ = 0x0001, + NPY_ARRAY_F_CONTIGUOUS_ = 0x0002, NPY_ARRAY_OWNDATA_ = 0x0004, NPY_ARRAY_FORCECAST_ = 0x0010, - NPY_ENSURE_ARRAY_ = 0x0040, + NPY_ARRAY_ENSUREARRAY_ = 0x0040, NPY_ARRAY_ALIGNED_ = 0x0100, NPY_ARRAY_WRITEABLE_ = 0x0400, NPY_BOOL_ = 0, @@ -155,6 +157,7 @@ struct npy_api { int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *, Py_ssize_t *, PyObject **, PyObject *); PyObject *(*PyArray_Squeeze_)(PyObject *); + int (*PyArray_SetBaseObject_)(PyObject *, PyObject *); private: enum functions { API_PyArray_Type = 2, @@ -169,7 +172,8 @@ private: API_PyArray_DescrConverter = 174, API_PyArray_EquivTypes = 182, API_PyArray_GetArrayParamsFromObject = 278, - API_PyArray_Squeeze = 136 + API_PyArray_Squeeze = 136, + API_PyArray_SetBaseObject = 282 }; static npy_api lookup() { @@ -195,6 +199,7 @@ private: DECL_NPY_API(PyArray_EquivTypes); DECL_NPY_API(PyArray_GetArrayParamsFromObject); DECL_NPY_API(PyArray_Squeeze); + DECL_NPY_API(PyArray_SetBaseObject); #undef DECL_NPY_API return api; } @@ -220,6 +225,128 @@ inline bool check_flags(const void* ptr, int flag) { return (flag == (array_proxy(ptr)->flags & 
flag)); } +template <typename T> struct is_std_array : std::false_type { }; +template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { }; +template <typename T> struct is_complex : std::false_type { }; +template <typename T> struct is_complex<std::complex<T>> : std::true_type { }; + +template <typename T> using is_pod_struct = all_of< + std::is_pod<T>, // since we're accessing directly in memory we need a POD type + satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum> +>; + +template <size_t Dim = 0, typename Strides> size_t byte_offset_unsafe(const Strides &) { return 0; } +template <size_t Dim = 0, typename Strides, typename... Ix> +size_t byte_offset_unsafe(const Strides &strides, size_t i, Ix... index) { + return i * strides[Dim] + byte_offset_unsafe<Dim + 1>(strides, index...); +} + +/** Proxy class providing unsafe, unchecked const access to array data. This is constructed through + * the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`. `Dims` + * will be -1 for dimensions determined at runtime. + */ +template <typename T, ssize_t Dims> +class unchecked_reference { +protected: + static constexpr bool Dynamic = Dims < 0; + const unsigned char *data_; + // Storing the shape & strides in local variables (i.e. 
these arrays) allows the compiler to + // make large performance gains on big, nested loops, but requires compile-time dimensions + conditional_t<Dynamic, const size_t *, std::array<size_t, (size_t) Dims>> + shape_, strides_; + const size_t dims_; + + friend class pybind11::array; + // Constructor for compile-time dimensions: + template <bool Dyn = Dynamic> + unchecked_reference(const void *data, const size_t *shape, const size_t *strides, enable_if_t<!Dyn, size_t>) + : data_{reinterpret_cast<const unsigned char *>(data)}, dims_{Dims} { + for (size_t i = 0; i < dims_; i++) { + shape_[i] = shape[i]; + strides_[i] = strides[i]; + } + } + // Constructor for runtime dimensions: + template <bool Dyn = Dynamic> + unchecked_reference(const void *data, const size_t *shape, const size_t *strides, enable_if_t<Dyn, size_t> dims) + : data_{reinterpret_cast<const unsigned char *>(data)}, shape_{shape}, strides_{strides}, dims_{dims} {} + +public: + /** Unchecked const reference access to data at the given indices. For a compile-time known + * number of dimensions, this requires the correct number of arguments; for run-time + * dimensionality, this is not checked (and so is up to the caller to use safely). + */ + template <typename... Ix> const T &operator()(Ix... index) const { + static_assert(sizeof...(Ix) == Dims || Dynamic, + "Invalid number of indices for unchecked array reference"); + return *reinterpret_cast<const T *>(data_ + byte_offset_unsafe(strides_, size_t(index)...)); + } + /** Unchecked const reference access to data; this operator only participates if the reference + * is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`. + */ + template <size_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>> + const T &operator[](size_t index) const { return operator()(index); } + + /// Pointer access to the data at the given indices. + template <typename... Ix> const T *data(Ix... 
ix) const { return &operator()(size_t(ix)...); } + + /// Returns the item size, i.e. sizeof(T) + constexpr static size_t itemsize() { return sizeof(T); } + + /// Returns the shape (i.e. size) of dimension `dim` + size_t shape(size_t dim) const { return shape_[dim]; } + + /// Returns the number of dimensions of the array + size_t ndim() const { return dims_; } + + /// Returns the total number of elements in the referenced array, i.e. the product of the shapes + template <bool Dyn = Dynamic> + enable_if_t<!Dyn, size_t> size() const { + return std::accumulate(shape_.begin(), shape_.end(), (size_t) 1, std::multiplies<size_t>()); + } + template <bool Dyn = Dynamic> + enable_if_t<Dyn, size_t> size() const { + return std::accumulate(shape_, shape_ + ndim(), (size_t) 1, std::multiplies<size_t>()); + } + + /// Returns the total number of bytes used by the referenced data. Note that the actual span in + /// memory may be larger if the referenced array has non-contiguous strides (e.g. for a slice). + size_t nbytes() const { + return size() * itemsize(); + } +}; + +template <typename T, ssize_t Dims> +class unchecked_mutable_reference : public unchecked_reference<T, Dims> { + friend class pybind11::array; + using ConstBase = unchecked_reference<T, Dims>; + using ConstBase::ConstBase; + using ConstBase::Dynamic; +public: + /// Mutable, unchecked access to data at the given indices. + template <typename... Ix> T& operator()(Ix... index) { + static_assert(sizeof...(Ix) == Dims || Dynamic, + "Invalid number of indices for unchecked array reference"); + return const_cast<T &>(ConstBase::operator()(index...)); + } + /** Mutable, unchecked access to data at the given index; this operator only participates if the + * reference is to a 1-dimensional array (or has runtime dimensions). When present, this is + * exactly equivalent to `obj(index)`. 
+ */ + template <size_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>> + T &operator[](size_t index) { return operator()(index); } + + /// Mutable pointer access to the data at the given indices. + template <typename... Ix> T *mutable_data(Ix... ix) { return &operator()(size_t(ix)...); } +}; + +template <typename T, size_t Dim> +struct type_caster<unchecked_reference<T, Dim>> { + static_assert(Dim == 0 && Dim > 0 /* always fail */, "unchecked array proxy object is not castable"); +}; +template <typename T, size_t Dim> +struct type_caster<unchecked_mutable_reference<T, Dim>> : type_caster<unchecked_reference<T, Dim>> {}; + NAMESPACE_END(detail) class dtype : public object { @@ -321,8 +448,8 @@ public: PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array) enum { - c_style = detail::npy_api::NPY_C_CONTIGUOUS_, - f_style = detail::npy_api::NPY_F_CONTIGUOUS_, + c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_, + f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_, forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_ }; @@ -340,7 +467,7 @@ public: int flags = 0; if (base && ptr) { if (isinstance<array>(base)) - /* Copy flags from base (except baseship bit) */ + /* Copy flags from base (except ownership bit) */ flags = reinterpret_borrow<array>(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_; else /* Writable by default, easy to downgrade later on if needed */ @@ -348,13 +475,15 @@ public: } auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_( - api.PyArray_Type_, descr.release().ptr(), (int) ndim, (Py_intptr_t *) shape.data(), - (Py_intptr_t *) strides.data(), const_cast<void *>(ptr), flags, nullptr)); + api.PyArray_Type_, descr.release().ptr(), (int) ndim, + reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(shape.data())), + reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(strides.data())), + const_cast<void *>(ptr), flags, nullptr)); if (!tmp) pybind11_fail("NumPy: unable to create array!"); if (ptr) { if (base) { 
- detail::array_proxy(tmp.ptr())->base = base.inc_ref().ptr(); + api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr()); } else { tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */)); } @@ -373,7 +502,7 @@ public: template<typename T> array(const std::vector<size_t>& shape, const std::vector<size_t>& strides, const T* ptr, handle base = handle()) - : array(pybind11::dtype::of<T>(), shape, strides, (void *) ptr, base) { } + : array(pybind11::dtype::of<T>(), shape, strides, (const void *) ptr, base) { } template <typename T> array(const std::vector<size_t> &shape, const T *ptr, @@ -486,6 +615,31 @@ public: return offset_at(index...) / itemsize(); } + /** Returns a proxy object that provides access to the array's data without bounds or + * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with + * care: the array must not be destroyed or reshaped for the duration of the returned object, + * and the caller must take care not to access invalid dimensions or dimension indices. + */ + template <typename T, ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() { + if (Dims >= 0 && ndim() != (size_t) Dims) + throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) + + "; expected " + std::to_string(Dims)); + return detail::unchecked_mutable_reference<T, Dims>(mutable_data(), shape(), strides(), ndim()); + } + + /** Returns a proxy object that provides const access to the array's data without bounds or + * dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the + * underlying array have the `writeable` flag. Use with care: the array must not be destroyed or + * reshaped for the duration of the returned object, and the caller must take care not to access + * invalid dimensions or dimension indices. 
+ */ + template <typename T, ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const { + if (Dims >= 0 && ndim() != (size_t) Dims) + throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) + + "; expected " + std::to_string(Dims)); + return detail::unchecked_reference<T, Dims>(data(), shape(), strides(), ndim()); + } + /// Return a new view with all of the dimensions of length 1 removed array squeeze() { auto& api = detail::npy_api::get(); @@ -511,18 +665,12 @@ protected: template<typename... Ix> size_t byte_offset(Ix... index) const { check_dimensions(index...); - return byte_offset_unsafe(index...); - } - - template<size_t dim = 0, typename... Ix> size_t byte_offset_unsafe(size_t i, Ix... index) const { - return i * strides()[dim] + byte_offset_unsafe<dim + 1>(index...); + return detail::byte_offset_unsafe(strides(), size_t(index)...); } - template<size_t dim = 0> size_t byte_offset_unsafe() const { return 0; } - void check_writeable() const { if (!writeable()) - throw std::runtime_error("array is not writeable"); + throw std::domain_error("array is not writeable"); } static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) { @@ -557,12 +705,14 @@ protected: if (ptr == nullptr) return nullptr; return detail::npy_api::get().PyArray_FromAny_( - ptr, nullptr, 0, 0, detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr); + ptr, nullptr, 0, 0, detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr); } }; template <typename T, int ExtraFlags = array::forcecast> class array_t : public array { public: + using value_type = T; + array_t() : array(0, static_cast<const T *>(nullptr)) {} array_t(handle h, borrowed_t) : array(h, borrowed) { } array_t(handle h, stolen_t) : array(h, stolen) { } @@ -621,8 +771,27 @@ public: return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) 
/ itemsize()); } - /// Ensure that the argument is a NumPy array of the correct dtype. - /// In case of an error, nullptr is returned and the Python error is cleared. + /** Returns a proxy object that provides access to the array's data without bounds or + * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with + * care: the array must not be destroyed or reshaped for the duration of the returned object, + * and the caller must take care not to access invalid dimensions or dimension indices. + */ + template <ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() { + return array::mutable_unchecked<T, Dims>(); + } + + /** Returns a proxy object that provides const access to the array's data without bounds or + * dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the underlying + * array have the `writeable` flag. Use with care: the array must not be destroyed or reshaped + * for the duration of the returned object, and the caller must take care not to access invalid + * dimensions or dimension indices. + */ + template <ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const { + return array::unchecked<T, Dims>(); + } + + /// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert + /// it). In case of an error, nullptr is returned and the Python error is cleared. 
static array_t ensure(handle h) { auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr())); if (!result) @@ -630,7 +799,7 @@ public: return result; } - static bool _check(handle h) { + static bool check_(handle h) { const auto &api = detail::npy_api::get(); return api.PyArray_Check_(h.ptr()) && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr()); @@ -643,7 +812,7 @@ protected: return nullptr; return detail::npy_api::get().PyArray_FromAny_( ptr, dtype::of<T>().release().ptr(), 0, 0, - detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr); + detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr); } }; @@ -674,7 +843,9 @@ template <typename T, int ExtraFlags> struct pyobject_caster<array_t<T, ExtraFlags>> { using type = array_t<T, ExtraFlags>; - bool load(handle src, bool /* convert */) { + bool load(handle src, bool convert) { + if (!convert && !type::check_(src)) + return false; value = type::ensure(src); return static_cast<bool>(value); } @@ -685,65 +856,55 @@ struct pyobject_caster<array_t<T, ExtraFlags>> { PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name()); }; -template <typename T> struct is_std_array : std::false_type { }; -template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { }; - template <typename T> -struct is_pod_struct { - enum { value = std::is_pod<T>::value && // offsetof only works correctly for POD types - !std::is_reference<T>::value && - !std::is_array<T>::value && - !is_std_array<T>::value && - !std::is_integral<T>::value && - !std::is_enum<T>::value && - !std::is_same<typename std::remove_cv<T>::type, float>::value && - !std::is_same<typename std::remove_cv<T>::type, double>::value && - !std::is_same<typename std::remove_cv<T>::type, bool>::value && - !std::is_same<typename std::remove_cv<T>::type, std::complex<float>>::value && - !std::is_same<typename std::remove_cv<T>::type, std::complex<double>>::value }; +struct compare_buffer_info<T, 
detail::enable_if_t<detail::is_pod_struct<T>::value>> { + static bool compare(const buffer_info& b) { + return npy_api::get().PyArray_EquivTypes_(dtype::of<T>().ptr(), dtype(b).ptr()); + } }; -template <typename T> struct npy_format_descriptor<T, enable_if_t<std::is_integral<T>::value>> { +template <typename T> struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>> { private: - constexpr static const int values[8] = { - npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_, npy_api::NPY_USHORT_, - npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_ }; + // NB: the order here must match the one in common.h + constexpr static const int values[15] = { + npy_api::NPY_BOOL_, + npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_, npy_api::NPY_USHORT_, + npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_, + npy_api::NPY_FLOAT_, npy_api::NPY_DOUBLE_, npy_api::NPY_LONGDOUBLE_, + npy_api::NPY_CFLOAT_, npy_api::NPY_CDOUBLE_, npy_api::NPY_CLONGDOUBLE_ + }; + public: - enum { value = values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned<T>::value ? 
1 : 0)] }; + static constexpr int value = values[detail::is_fmt_numeric<T>::index]; + static pybind11::dtype dtype() { if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) return reinterpret_borrow<pybind11::dtype>(ptr); pybind11_fail("Unsupported buffer format!"); } - template <typename T2 = T, enable_if_t<std::is_signed<T2>::value, int> = 0> - static PYBIND11_DESCR name() { return _("int") + _<sizeof(T)*8>(); } - template <typename T2 = T, enable_if_t<!std::is_signed<T2>::value, int> = 0> - static PYBIND11_DESCR name() { return _("uint") + _<sizeof(T)*8>(); } + template <typename T2 = T, enable_if_t<std::is_integral<T2>::value, int> = 0> + static PYBIND11_DESCR name() { + return _<std::is_same<T, bool>::value>(_("bool"), + _<std::is_signed<T>::value>("int", "uint") + _<sizeof(T)*8>()); + } + template <typename T2 = T, enable_if_t<std::is_floating_point<T2>::value, int> = 0> + static PYBIND11_DESCR name() { + return _<std::is_same<T, float>::value || std::is_same<T, double>::value>( + _("float") + _<sizeof(T)*8>(), _("longdouble")); + } + template <typename T2 = T, enable_if_t<is_complex<T2>::value, int> = 0> + static PYBIND11_DESCR name() { + return _<std::is_same<typename T2::value_type, float>::value || std::is_same<typename T2::value_type, double>::value>( + _("complex") + _<sizeof(typename T2::value_type)*16>(), _("longcomplex")); + } }; -template <typename T> constexpr const int npy_format_descriptor< - T, enable_if_t<std::is_integral<T>::value>>::values[8]; - -#define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \ - enum { value = npy_api::NumPyName }; \ - static pybind11::dtype dtype() { \ - if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) \ - return reinterpret_borrow<pybind11::dtype>(ptr); \ - pybind11_fail("Unsupported buffer format!"); \ - } \ - static PYBIND11_DESCR name() { return _(Name); } } -DECL_FMT(float, NPY_FLOAT_, "float32"); -DECL_FMT(double, NPY_DOUBLE_, "float64"); -DECL_FMT(bool, 
NPY_BOOL_, "bool"); -DECL_FMT(std::complex<float>, NPY_CFLOAT_, "complex64"); -DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128"); -#undef DECL_FMT - -#define DECL_CHAR_FMT \ + +#define PYBIND11_DECL_CHAR_FMT \ static PYBIND11_DESCR name() { return _("S") + _<N>(); } \ static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); } -template <size_t N> struct npy_format_descriptor<char[N]> { DECL_CHAR_FMT }; -template <size_t N> struct npy_format_descriptor<std::array<char, N>> { DECL_CHAR_FMT }; -#undef DECL_CHAR_FMT +template <size_t N> struct npy_format_descriptor<char[N]> { PYBIND11_DECL_CHAR_FMT }; +template <size_t N> struct npy_format_descriptor<std::array<char, N>> { PYBIND11_DECL_CHAR_FMT }; +#undef PYBIND11_DECL_CHAR_FMT template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> { private: @@ -798,9 +959,9 @@ inline PYBIND11_NOINLINE void register_structured_dtype( for (auto& field : ordered_fields) { if (field.offset > offset) oss << (field.offset - offset) << 'x'; - // mark unaligned fields with '=' + // mark unaligned fields with '^' (unaligned native type) if (field.offset % field.alignment) - oss << '='; + oss << '^'; oss << field.format << ':' << field.name << ':'; offset = field.offset + field.size; } @@ -820,9 +981,10 @@ inline PYBIND11_NOINLINE void register_structured_dtype( get_internals().direct_conversions[tindex].push_back(direct_converter); } -template <typename T> -struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> { - static PYBIND11_DESCR name() { return _("struct"); } +template <typename T, typename SFINAE> struct npy_format_descriptor { + static_assert(is_pod_struct<T>::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype"); + + static PYBIND11_DESCR name() { return make_caster<T>::name(); } static pybind11::dtype dtype() { return reinterpret_borrow<pybind11::dtype>(dtype_ptr()); @@ -1043,87 +1205,146 @@ private: 
std::array<common_iter, N> m_common_iterator; }; +enum class broadcast_trivial { non_trivial, c_trivial, f_trivial }; + +// Populates the shape and number of dimensions for the set of buffers. Returns a broadcast_trivial +// enum value indicating whether the broadcast is "trivial"--that is, has each buffer being either a +// singleton or a full-size, C-contiguous (`c_trivial`) or Fortran-contiguous (`f_trivial`) storage +// buffer; returns `non_trivial` otherwise. template <size_t N> -bool broadcast(const std::array<buffer_info, N>& buffers, size_t& ndim, std::vector<size_t>& shape) { +broadcast_trivial broadcast(const std::array<buffer_info, N> &buffers, size_t &ndim, std::vector<size_t> &shape) { ndim = std::accumulate(buffers.begin(), buffers.end(), size_t(0), [](size_t res, const buffer_info& buf) { return std::max(res, buf.ndim); }); - shape = std::vector<size_t>(ndim, 1); - bool trivial_broadcast = true; + shape.clear(); + shape.resize(ndim, 1); + + // Figure out the output size, and make sure all input arrays conform (i.e. are either size 1 or + // the full size). 
for (size_t i = 0; i < N; ++i) { auto res_iter = shape.rbegin(); - bool i_trivial_broadcast = (buffers[i].size == 1) || (buffers[i].ndim == ndim); - for (auto shape_iter = buffers[i].shape.rbegin(); - shape_iter != buffers[i].shape.rend(); ++shape_iter, ++res_iter) { - - if (*res_iter == 1) - *res_iter = *shape_iter; - else if ((*shape_iter != 1) && (*res_iter != *shape_iter)) + auto end = buffers[i].shape.rend(); + for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != end; ++shape_iter, ++res_iter) { + const auto &dim_size_in = *shape_iter; + auto &dim_size_out = *res_iter; + + // Each input dimension can either be 1 or `n`, but `n` values must match across buffers + if (dim_size_out == 1) + dim_size_out = dim_size_in; + else if (dim_size_in != 1 && dim_size_in != dim_size_out) pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!"); + } + } - i_trivial_broadcast = i_trivial_broadcast && (*res_iter == *shape_iter); + bool trivial_broadcast_c = true; + bool trivial_broadcast_f = true; + for (size_t i = 0; i < N && (trivial_broadcast_c || trivial_broadcast_f); ++i) { + if (buffers[i].size == 1) + continue; + + // Require the same number of dimensions: + if (buffers[i].ndim != ndim) + return broadcast_trivial::non_trivial; + + // Require all dimensions be full-size: + if (!std::equal(buffers[i].shape.cbegin(), buffers[i].shape.cend(), shape.cbegin())) + return broadcast_trivial::non_trivial; + + // Check for C contiguity (but only if previous inputs were also C contiguous) + if (trivial_broadcast_c) { + size_t expect_stride = buffers[i].itemsize; + auto end = buffers[i].shape.crend(); + for (auto shape_iter = buffers[i].shape.crbegin(), stride_iter = buffers[i].strides.crbegin(); + trivial_broadcast_c && shape_iter != end; ++shape_iter, ++stride_iter) { + if (expect_stride == *stride_iter) + expect_stride *= *shape_iter; + else + trivial_broadcast_c = false; + } + } + + // Check for Fortran contiguity (if previous inputs were also F 
contiguous) + if (trivial_broadcast_f) { + size_t expect_stride = buffers[i].itemsize; + auto end = buffers[i].shape.cend(); + for (auto shape_iter = buffers[i].shape.cbegin(), stride_iter = buffers[i].strides.cbegin(); + trivial_broadcast_f && shape_iter != end; ++shape_iter, ++stride_iter) { + if (expect_stride == *stride_iter) + expect_stride *= *shape_iter; + else + trivial_broadcast_f = false; + } } - trivial_broadcast = trivial_broadcast && i_trivial_broadcast; } - return trivial_broadcast; + + return + trivial_broadcast_c ? broadcast_trivial::c_trivial : + trivial_broadcast_f ? broadcast_trivial::f_trivial : + broadcast_trivial::non_trivial; } template <typename Func, typename Return, typename... Args> struct vectorize_helper { typename std::remove_reference<Func>::type f; + static constexpr size_t N = sizeof...(Args); template <typename T> explicit vectorize_helper(T&&f) : f(std::forward<T>(f)) { } - object operator()(array_t<Args, array::c_style | array::forcecast>... args) { - return run(args..., make_index_sequence<sizeof...(Args)>()); + object operator()(array_t<Args, array::forcecast>... args) { + return run(args..., make_index_sequence<N>()); } - template <size_t ... Index> object run(array_t<Args, array::c_style | array::forcecast>&... args, index_sequence<Index...> index) { + template <size_t ... Index> object run(array_t<Args, array::forcecast>&... args, index_sequence<Index...> index) { /* Request buffers from all parameters */ - const size_t N = sizeof...(Args); - std::array<buffer_info, N> buffers {{ args.request()... 
}}; /* Determine dimensions parameters of output array */ size_t ndim = 0; std::vector<size_t> shape(0); - bool trivial_broadcast = broadcast(buffers, ndim, shape); + auto trivial = broadcast(buffers, ndim, shape); size_t size = 1; std::vector<size_t> strides(ndim); if (ndim > 0) { - strides[ndim-1] = sizeof(Return); - for (size_t i = ndim - 1; i > 0; --i) { - strides[i - 1] = strides[i] * shape[i]; - size *= shape[i]; + if (trivial == broadcast_trivial::f_trivial) { + strides[0] = sizeof(Return); + for (size_t i = 1; i < ndim; ++i) { + strides[i] = strides[i - 1] * shape[i - 1]; + size *= shape[i - 1]; + } + size *= shape[ndim - 1]; + } + else { + strides[ndim-1] = sizeof(Return); + for (size_t i = ndim - 1; i > 0; --i) { + strides[i - 1] = strides[i] * shape[i]; + size *= shape[i]; + } + size *= shape[0]; } - size *= shape[0]; } if (size == 1) - return cast(f(*((Args *) buffers[Index].ptr)...)); + return cast(f(*reinterpret_cast<Args *>(buffers[Index].ptr)...)); array_t<Return> result(shape, strides); auto buf = result.request(); auto output = (Return *) buf.ptr; - if (trivial_broadcast) { - /* Call the function */ - for (size_t i = 0; i < size; ++i) { - output[i] = f((buffers[Index].size == 1 - ? *((Args *) buffers[Index].ptr) - : ((Args *) buffers[Index].ptr)[i])...); - } + /* Call the function */ + if (trivial == broadcast_trivial::non_trivial) { + apply_broadcast<Index...>(buffers, buf, index); } else { - apply_broadcast<N, Index...>(buffers, buf, index); + for (size_t i = 0; i < size; ++i) + output[i] = f((reinterpret_cast<Args *>(buffers[Index].ptr)[buffers[Index].size == 1 ? 0 : i])...); } return result; } - template <size_t N, size_t... Index> + template <size_t... 
Index> void apply_broadcast(const std::array<buffer_info, N> &buffers, buffer_info &output, index_sequence<Index...>) { using input_iterator = multi_array_iterator<N>; @@ -1140,26 +1361,29 @@ struct vectorize_helper { }; template <typename T, int Flags> struct handle_type_name<array_t<T, Flags>> { - static PYBIND11_DESCR name() { return _("numpy.ndarray[") + type_caster<T>::name() + _("]"); } + static PYBIND11_DESCR name() { + return _("numpy.ndarray[") + npy_format_descriptor<T>::name() + _("]"); + } }; NAMESPACE_END(detail) template <typename Func, typename Return, typename... Args> -detail::vectorize_helper<Func, Return, Args...> vectorize(const Func &f, Return (*) (Args ...)) { +detail::vectorize_helper<Func, Return, Args...> +vectorize(const Func &f, Return (*) (Args ...)) { return detail::vectorize_helper<Func, Return, Args...>(f); } template <typename Return, typename... Args> -detail::vectorize_helper<Return (*) (Args ...), Return, Args...> vectorize(Return (*f) (Args ...)) { +detail::vectorize_helper<Return (*) (Args ...), Return, Args...> +vectorize(Return (*f) (Args ...)) { return vectorize<Return (*) (Args ...), Return, Args...>(f, f); } -template <typename Func> +template <typename Func, typename FuncType = typename detail::remove_class<decltype(&std::remove_reference<Func>::type::operator())>::type> auto vectorize(Func &&f) -> decltype( - vectorize(std::forward<Func>(f), (typename detail::remove_class<decltype(&std::remove_reference<Func>::type::operator())>::type *) nullptr)) { - return vectorize(std::forward<Func>(f), (typename detail::remove_class<decltype( - &std::remove_reference<Func>::type::operator())>::type *) nullptr); + vectorize(std::forward<Func>(f), (FuncType *) nullptr)) { + return vectorize(std::forward<Func>(f), (FuncType *) nullptr); } NAMESPACE_END(pybind11) |