summaryrefslogtreecommitdiff
path: root/ext/pybind11/include/pybind11/numpy.h
diff options
context:
space:
mode:
Diffstat (limited to 'ext/pybind11/include/pybind11/numpy.h')
-rw-r--r--ext/pybind11/include/pybind11/numpy.h462
1 files changed, 343 insertions, 119 deletions
diff --git a/ext/pybind11/include/pybind11/numpy.h b/ext/pybind11/include/pybind11/numpy.h
index e6f4efdf9..3227a12eb 100644
--- a/ext/pybind11/include/pybind11/numpy.h
+++ b/ext/pybind11/include/pybind11/numpy.h
@@ -35,9 +35,11 @@
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");
NAMESPACE_BEGIN(pybind11)
+
+class array; // Forward declaration
+
NAMESPACE_BEGIN(detail)
-template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
-template <typename type> struct is_pod_struct;
+template <typename type, typename SFINAE = void> struct npy_format_descriptor;
struct PyArrayDescr_Proxy {
PyObject_HEAD
@@ -108,11 +110,11 @@ inline numpy_internals& get_numpy_internals() {
struct npy_api {
enum constants {
- NPY_C_CONTIGUOUS_ = 0x0001,
- NPY_F_CONTIGUOUS_ = 0x0002,
+ NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
+ NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
NPY_ARRAY_OWNDATA_ = 0x0004,
NPY_ARRAY_FORCECAST_ = 0x0010,
- NPY_ENSURE_ARRAY_ = 0x0040,
+ NPY_ARRAY_ENSUREARRAY_ = 0x0040,
NPY_ARRAY_ALIGNED_ = 0x0100,
NPY_ARRAY_WRITEABLE_ = 0x0400,
NPY_BOOL_ = 0,
@@ -155,6 +157,7 @@ struct npy_api {
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
Py_ssize_t *, PyObject **, PyObject *);
PyObject *(*PyArray_Squeeze_)(PyObject *);
+ int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
private:
enum functions {
API_PyArray_Type = 2,
@@ -169,7 +172,8 @@ private:
API_PyArray_DescrConverter = 174,
API_PyArray_EquivTypes = 182,
API_PyArray_GetArrayParamsFromObject = 278,
- API_PyArray_Squeeze = 136
+ API_PyArray_Squeeze = 136,
+ API_PyArray_SetBaseObject = 282
};
static npy_api lookup() {
@@ -195,6 +199,7 @@ private:
DECL_NPY_API(PyArray_EquivTypes);
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
DECL_NPY_API(PyArray_Squeeze);
+ DECL_NPY_API(PyArray_SetBaseObject);
#undef DECL_NPY_API
return api;
}
@@ -220,6 +225,128 @@ inline bool check_flags(const void* ptr, int flag) {
return (flag == (array_proxy(ptr)->flags & flag));
}
+template <typename T> struct is_std_array : std::false_type { };
+template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { };
+template <typename T> struct is_complex : std::false_type { };
+template <typename T> struct is_complex<std::complex<T>> : std::true_type { };
+
+template <typename T> using is_pod_struct = all_of<
+ std::is_pod<T>, // since we're accessing directly in memory we need a POD type
+ satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
+>;
+
+template <size_t Dim = 0, typename Strides> size_t byte_offset_unsafe(const Strides &) { return 0; }
+template <size_t Dim = 0, typename Strides, typename... Ix>
+size_t byte_offset_unsafe(const Strides &strides, size_t i, Ix... index) {
+ return i * strides[Dim] + byte_offset_unsafe<Dim + 1>(strides, index...);
+}
+
+/** Proxy class providing unsafe, unchecked const access to array data. This is constructed through
+ * the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`. `Dims`
+ * will be -1 for dimensions determined at runtime.
+ */
+template <typename T, ssize_t Dims>
+class unchecked_reference {
+protected:
+ static constexpr bool Dynamic = Dims < 0;
+ const unsigned char *data_;
+ // Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to
+ // make large performance gains on big, nested loops, but requires compile-time dimensions
+ conditional_t<Dynamic, const size_t *, std::array<size_t, (size_t) Dims>>
+ shape_, strides_;
+ const size_t dims_;
+
+ friend class pybind11::array;
+ // Constructor for compile-time dimensions:
+ template <bool Dyn = Dynamic>
+ unchecked_reference(const void *data, const size_t *shape, const size_t *strides, enable_if_t<!Dyn, size_t>)
+ : data_{reinterpret_cast<const unsigned char *>(data)}, dims_{Dims} {
+ for (size_t i = 0; i < dims_; i++) {
+ shape_[i] = shape[i];
+ strides_[i] = strides[i];
+ }
+ }
+ // Constructor for runtime dimensions:
+ template <bool Dyn = Dynamic>
+ unchecked_reference(const void *data, const size_t *shape, const size_t *strides, enable_if_t<Dyn, size_t> dims)
+ : data_{reinterpret_cast<const unsigned char *>(data)}, shape_{shape}, strides_{strides}, dims_{dims} {}
+
+public:
+ /** Unchecked const reference access to data at the given indices. For a compile-time known
+ * number of dimensions, this requires the correct number of arguments; for run-time
+ * dimensionality, this is not checked (and so is up to the caller to use safely).
+ */
+ template <typename... Ix> const T &operator()(Ix... index) const {
+ static_assert(sizeof...(Ix) == Dims || Dynamic,
+ "Invalid number of indices for unchecked array reference");
+ return *reinterpret_cast<const T *>(data_ + byte_offset_unsafe(strides_, size_t(index)...));
+ }
+ /** Unchecked const reference access to data; this operator only participates if the reference
+ * is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`.
+ */
+ template <size_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
+ const T &operator[](size_t index) const { return operator()(index); }
+
+ /// Pointer access to the data at the given indices.
+ template <typename... Ix> const T *data(Ix... ix) const { return &operator()(size_t(ix)...); }
+
+ /// Returns the item size, i.e. sizeof(T)
+ constexpr static size_t itemsize() { return sizeof(T); }
+
+ /// Returns the shape (i.e. size) of dimension `dim`
+ size_t shape(size_t dim) const { return shape_[dim]; }
+
+ /// Returns the number of dimensions of the array
+ size_t ndim() const { return dims_; }
+
+ /// Returns the total number of elements in the referenced array, i.e. the product of the shapes
+ template <bool Dyn = Dynamic>
+ enable_if_t<!Dyn, size_t> size() const {
+ return std::accumulate(shape_.begin(), shape_.end(), (size_t) 1, std::multiplies<size_t>());
+ }
+ template <bool Dyn = Dynamic>
+ enable_if_t<Dyn, size_t> size() const {
+ return std::accumulate(shape_, shape_ + ndim(), (size_t) 1, std::multiplies<size_t>());
+ }
+
+ /// Returns the total number of bytes used by the referenced data. Note that the actual span in
+ /// memory may be larger if the referenced array has non-contiguous strides (e.g. for a slice).
+ size_t nbytes() const {
+ return size() * itemsize();
+ }
+};
+
+template <typename T, ssize_t Dims>
+class unchecked_mutable_reference : public unchecked_reference<T, Dims> {
+ friend class pybind11::array;
+ using ConstBase = unchecked_reference<T, Dims>;
+ using ConstBase::ConstBase;
+ using ConstBase::Dynamic;
+public:
+ /// Mutable, unchecked access to data at the given indices.
+ template <typename... Ix> T& operator()(Ix... index) {
+ static_assert(sizeof...(Ix) == Dims || Dynamic,
+ "Invalid number of indices for unchecked array reference");
+ return const_cast<T &>(ConstBase::operator()(index...));
+ }
+ /** Mutable, unchecked access data at the given index; this operator only participates if the
+ * reference is to a 1-dimensional array (or has runtime dimensions). When present, this is
+ * exactly equivalent to `obj(index)`.
+ */
+ template <size_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
+ T &operator[](size_t index) { return operator()(index); }
+
+ /// Mutable pointer access to the data at the given indices.
+ template <typename... Ix> T *mutable_data(Ix... ix) { return &operator()(size_t(ix)...); }
+};
+
+template <typename T, size_t Dim>
+struct type_caster<unchecked_reference<T, Dim>> {
+ static_assert(Dim == 0 && Dim > 0 /* always fail */, "unchecked array proxy object is not castable");
+};
+template <typename T, size_t Dim>
+struct type_caster<unchecked_mutable_reference<T, Dim>> : type_caster<unchecked_reference<T, Dim>> {};
+
NAMESPACE_END(detail)
class dtype : public object {
@@ -321,8 +448,8 @@ public:
PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
enum {
- c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
- f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
+ c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_,
+ f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_,
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
};
@@ -340,7 +467,7 @@ public:
int flags = 0;
if (base && ptr) {
if (isinstance<array>(base))
- /* Copy flags from base (except baseship bit) */
+ /* Copy flags from base (except ownership bit) */
flags = reinterpret_borrow<array>(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
else
/* Writable by default, easy to downgrade later on if needed */
@@ -348,13 +475,15 @@ public:
}
auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
- api.PyArray_Type_, descr.release().ptr(), (int) ndim, (Py_intptr_t *) shape.data(),
- (Py_intptr_t *) strides.data(), const_cast<void *>(ptr), flags, nullptr));
+ api.PyArray_Type_, descr.release().ptr(), (int) ndim,
+ reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(shape.data())),
+ reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(strides.data())),
+ const_cast<void *>(ptr), flags, nullptr));
if (!tmp)
pybind11_fail("NumPy: unable to create array!");
if (ptr) {
if (base) {
- detail::array_proxy(tmp.ptr())->base = base.inc_ref().ptr();
+ api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr());
} else {
tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
}
@@ -373,7 +502,7 @@ public:
template<typename T> array(const std::vector<size_t>& shape,
const std::vector<size_t>& strides,
const T* ptr, handle base = handle())
- : array(pybind11::dtype::of<T>(), shape, strides, (void *) ptr, base) { }
+ : array(pybind11::dtype::of<T>(), shape, strides, (const void *) ptr, base) { }
template <typename T>
array(const std::vector<size_t> &shape, const T *ptr,
@@ -486,6 +615,31 @@ public:
return offset_at(index...) / itemsize();
}
+ /** Returns a proxy object that provides access to the array's data without bounds or
+ * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
+ * care: the array must not be destroyed or reshaped for the duration of the returned object,
+ * and the caller must take care not to access invalid dimensions or dimension indices.
+ */
+ template <typename T, ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
+ if (Dims >= 0 && ndim() != (size_t) Dims)
+ throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
+ "; expected " + std::to_string(Dims));
+ return detail::unchecked_mutable_reference<T, Dims>(mutable_data(), shape(), strides(), ndim());
+ }
+
+ /** Returns a proxy object that provides const access to the array's data without bounds or
+ * dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the
+ * underlying array have the `writable` flag. Use with care: the array must not be destroyed or
+ * reshaped for the duration of the returned object, and the caller must take care not to access
+ * invalid dimensions or dimension indices.
+ */
+ template <typename T, ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const {
+ if (Dims >= 0 && ndim() != (size_t) Dims)
+ throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
+ "; expected " + std::to_string(Dims));
+ return detail::unchecked_reference<T, Dims>(data(), shape(), strides(), ndim());
+ }
+
/// Return a new view with all of the dimensions of length 1 removed
array squeeze() {
auto& api = detail::npy_api::get();
@@ -511,18 +665,12 @@ protected:
template<typename... Ix> size_t byte_offset(Ix... index) const {
check_dimensions(index...);
- return byte_offset_unsafe(index...);
- }
-
- template<size_t dim = 0, typename... Ix> size_t byte_offset_unsafe(size_t i, Ix... index) const {
- return i * strides()[dim] + byte_offset_unsafe<dim + 1>(index...);
+ return detail::byte_offset_unsafe(strides(), size_t(index)...);
}
- template<size_t dim = 0> size_t byte_offset_unsafe() const { return 0; }
-
void check_writeable() const {
if (!writeable())
- throw std::runtime_error("array is not writeable");
+ throw std::domain_error("array is not writeable");
}
static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) {
@@ -557,12 +705,14 @@ protected:
if (ptr == nullptr)
return nullptr;
return detail::npy_api::get().PyArray_FromAny_(
- ptr, nullptr, 0, 0, detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
+ ptr, nullptr, 0, 0, detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
}
};
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
public:
+ using value_type = T;
+
array_t() : array(0, static_cast<const T *>(nullptr)) {}
array_t(handle h, borrowed_t) : array(h, borrowed) { }
array_t(handle h, stolen_t) : array(h, stolen) { }
@@ -621,8 +771,27 @@ public:
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
}
- /// Ensure that the argument is a NumPy array of the correct dtype.
- /// In case of an error, nullptr is returned and the Python error is cleared.
+ /** Returns a proxy object that provides access to the array's data without bounds or
+ * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
+ * care: the array must not be destroyed or reshaped for the duration of the returned object,
+ * and the caller must take care not to access invalid dimensions or dimension indices.
+ */
+ template <ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
+ return array::mutable_unchecked<T, Dims>();
+ }
+
+ /** Returns a proxy object that provides const access to the array's data without bounds or
+ * dimensionality checking. Unlike `unchecked()`, this does not require that the underlying
+ * array have the `writable` flag. Use with care: the array must not be destroyed or reshaped
+ * for the duration of the returned object, and the caller must take care not to access invalid
+ * dimensions or dimension indices.
+ */
+ template <ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const {
+ return array::unchecked<T, Dims>();
+ }
+
+ /// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
+ /// it). In case of an error, nullptr is returned and the Python error is cleared.
static array_t ensure(handle h) {
auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
if (!result)
@@ -630,7 +799,7 @@ public:
return result;
}
- static bool _check(handle h) {
+ static bool check_(handle h) {
const auto &api = detail::npy_api::get();
return api.PyArray_Check_(h.ptr())
&& api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr());
@@ -643,7 +812,7 @@ protected:
return nullptr;
return detail::npy_api::get().PyArray_FromAny_(
ptr, dtype::of<T>().release().ptr(), 0, 0,
- detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
+ detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
}
};
@@ -674,7 +843,9 @@ template <typename T, int ExtraFlags>
struct pyobject_caster<array_t<T, ExtraFlags>> {
using type = array_t<T, ExtraFlags>;
- bool load(handle src, bool /* convert */) {
+ bool load(handle src, bool convert) {
+ if (!convert && !type::check_(src))
+ return false;
value = type::ensure(src);
return static_cast<bool>(value);
}
@@ -685,65 +856,55 @@ struct pyobject_caster<array_t<T, ExtraFlags>> {
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
};
-template <typename T> struct is_std_array : std::false_type { };
-template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { };
-
template <typename T>
-struct is_pod_struct {
- enum { value = std::is_pod<T>::value && // offsetof only works correctly for POD types
- !std::is_reference<T>::value &&
- !std::is_array<T>::value &&
- !is_std_array<T>::value &&
- !std::is_integral<T>::value &&
- !std::is_enum<T>::value &&
- !std::is_same<typename std::remove_cv<T>::type, float>::value &&
- !std::is_same<typename std::remove_cv<T>::type, double>::value &&
- !std::is_same<typename std::remove_cv<T>::type, bool>::value &&
- !std::is_same<typename std::remove_cv<T>::type, std::complex<float>>::value &&
- !std::is_same<typename std::remove_cv<T>::type, std::complex<double>>::value };
+struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
+ static bool compare(const buffer_info& b) {
+ return npy_api::get().PyArray_EquivTypes_(dtype::of<T>().ptr(), dtype(b).ptr());
+ }
};
-template <typename T> struct npy_format_descriptor<T, enable_if_t<std::is_integral<T>::value>> {
+template <typename T> struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>> {
private:
- constexpr static const int values[8] = {
- npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_, npy_api::NPY_USHORT_,
- npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_ };
+ // NB: the order here must match the one in common.h
+ constexpr static const int values[15] = {
+ npy_api::NPY_BOOL_,
+ npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_, npy_api::NPY_USHORT_,
+ npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_,
+ npy_api::NPY_FLOAT_, npy_api::NPY_DOUBLE_, npy_api::NPY_LONGDOUBLE_,
+ npy_api::NPY_CFLOAT_, npy_api::NPY_CDOUBLE_, npy_api::NPY_CLONGDOUBLE_
+ };
+
public:
- enum { value = values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned<T>::value ? 1 : 0)] };
+ static constexpr int value = values[detail::is_fmt_numeric<T>::index];
+
static pybind11::dtype dtype() {
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value))
return reinterpret_borrow<pybind11::dtype>(ptr);
pybind11_fail("Unsupported buffer format!");
}
- template <typename T2 = T, enable_if_t<std::is_signed<T2>::value, int> = 0>
- static PYBIND11_DESCR name() { return _("int") + _<sizeof(T)*8>(); }
- template <typename T2 = T, enable_if_t<!std::is_signed<T2>::value, int> = 0>
- static PYBIND11_DESCR name() { return _("uint") + _<sizeof(T)*8>(); }
+ template <typename T2 = T, enable_if_t<std::is_integral<T2>::value, int> = 0>
+ static PYBIND11_DESCR name() {
+ return _<std::is_same<T, bool>::value>(_("bool"),
+ _<std::is_signed<T>::value>("int", "uint") + _<sizeof(T)*8>());
+ }
+ template <typename T2 = T, enable_if_t<std::is_floating_point<T2>::value, int> = 0>
+ static PYBIND11_DESCR name() {
+ return _<std::is_same<T, float>::value || std::is_same<T, double>::value>(
+ _("float") + _<sizeof(T)*8>(), _("longdouble"));
+ }
+ template <typename T2 = T, enable_if_t<is_complex<T2>::value, int> = 0>
+ static PYBIND11_DESCR name() {
+ return _<std::is_same<typename T2::value_type, float>::value || std::is_same<typename T2::value_type, double>::value>(
+ _("complex") + _<sizeof(typename T2::value_type)*16>(), _("longcomplex"));
+ }
};
-template <typename T> constexpr const int npy_format_descriptor<
- T, enable_if_t<std::is_integral<T>::value>>::values[8];
-
-#define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \
- enum { value = npy_api::NumPyName }; \
- static pybind11::dtype dtype() { \
- if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) \
- return reinterpret_borrow<pybind11::dtype>(ptr); \
- pybind11_fail("Unsupported buffer format!"); \
- } \
- static PYBIND11_DESCR name() { return _(Name); } }
-DECL_FMT(float, NPY_FLOAT_, "float32");
-DECL_FMT(double, NPY_DOUBLE_, "float64");
-DECL_FMT(bool, NPY_BOOL_, "bool");
-DECL_FMT(std::complex<float>, NPY_CFLOAT_, "complex64");
-DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
-#undef DECL_FMT
-
-#define DECL_CHAR_FMT \
+
+#define PYBIND11_DECL_CHAR_FMT \
static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); }
-template <size_t N> struct npy_format_descriptor<char[N]> { DECL_CHAR_FMT };
-template <size_t N> struct npy_format_descriptor<std::array<char, N>> { DECL_CHAR_FMT };
-#undef DECL_CHAR_FMT
+template <size_t N> struct npy_format_descriptor<char[N]> { PYBIND11_DECL_CHAR_FMT };
+template <size_t N> struct npy_format_descriptor<std::array<char, N>> { PYBIND11_DECL_CHAR_FMT };
+#undef PYBIND11_DECL_CHAR_FMT
template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
private:
@@ -798,9 +959,9 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
for (auto& field : ordered_fields) {
if (field.offset > offset)
oss << (field.offset - offset) << 'x';
- // mark unaligned fields with '='
+ // mark unaligned fields with '^' (unaligned native type)
if (field.offset % field.alignment)
- oss << '=';
+ oss << '^';
oss << field.format << ':' << field.name << ':';
offset = field.offset + field.size;
}
@@ -820,9 +981,10 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
get_internals().direct_conversions[tindex].push_back(direct_converter);
}
-template <typename T>
-struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
- static PYBIND11_DESCR name() { return _("struct"); }
+template <typename T, typename SFINAE> struct npy_format_descriptor {
+ static_assert(is_pod_struct<T>::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype");
+
+ static PYBIND11_DESCR name() { return make_caster<T>::name(); }
static pybind11::dtype dtype() {
return reinterpret_borrow<pybind11::dtype>(dtype_ptr());
@@ -1043,87 +1205,146 @@ private:
std::array<common_iter, N> m_common_iterator;
};
+enum class broadcast_trivial { non_trivial, c_trivial, f_trivial };
+
+// Populates the shape and number of dimensions for the set of buffers. Returns a broadcast_trivial
+// enum value indicating whether the broadcast is "trivial"--that is, has each buffer being either a
+// singleton or a full-size, C-contiguous (`c_trivial`) or Fortran-contiguous (`f_trivial`) storage
+// buffer; returns `non_trivial` otherwise.
template <size_t N>
-bool broadcast(const std::array<buffer_info, N>& buffers, size_t& ndim, std::vector<size_t>& shape) {
+broadcast_trivial broadcast(const std::array<buffer_info, N> &buffers, size_t &ndim, std::vector<size_t> &shape) {
ndim = std::accumulate(buffers.begin(), buffers.end(), size_t(0), [](size_t res, const buffer_info& buf) {
return std::max(res, buf.ndim);
});
- shape = std::vector<size_t>(ndim, 1);
- bool trivial_broadcast = true;
+ shape.clear();
+ shape.resize(ndim, 1);
+
+ // Figure out the output size, and make sure all input arrays conform (i.e. are either size 1 or
+ // the full size).
for (size_t i = 0; i < N; ++i) {
auto res_iter = shape.rbegin();
- bool i_trivial_broadcast = (buffers[i].size == 1) || (buffers[i].ndim == ndim);
- for (auto shape_iter = buffers[i].shape.rbegin();
- shape_iter != buffers[i].shape.rend(); ++shape_iter, ++res_iter) {
-
- if (*res_iter == 1)
- *res_iter = *shape_iter;
- else if ((*shape_iter != 1) && (*res_iter != *shape_iter))
+ auto end = buffers[i].shape.rend();
+ for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != end; ++shape_iter, ++res_iter) {
+ const auto &dim_size_in = *shape_iter;
+ auto &dim_size_out = *res_iter;
+
+ // Each input dimension can either be 1 or `n`, but `n` values must match across buffers
+ if (dim_size_out == 1)
+ dim_size_out = dim_size_in;
+ else if (dim_size_in != 1 && dim_size_in != dim_size_out)
pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
+ }
+ }
- i_trivial_broadcast = i_trivial_broadcast && (*res_iter == *shape_iter);
+ bool trivial_broadcast_c = true;
+ bool trivial_broadcast_f = true;
+ for (size_t i = 0; i < N && (trivial_broadcast_c || trivial_broadcast_f); ++i) {
+ if (buffers[i].size == 1)
+ continue;
+
+ // Require the same number of dimensions:
+ if (buffers[i].ndim != ndim)
+ return broadcast_trivial::non_trivial;
+
+ // Require all dimensions be full-size:
+ if (!std::equal(buffers[i].shape.cbegin(), buffers[i].shape.cend(), shape.cbegin()))
+ return broadcast_trivial::non_trivial;
+
+ // Check for C contiguity (but only if previous inputs were also C contiguous)
+ if (trivial_broadcast_c) {
+ size_t expect_stride = buffers[i].itemsize;
+ auto end = buffers[i].shape.crend();
+ for (auto shape_iter = buffers[i].shape.crbegin(), stride_iter = buffers[i].strides.crbegin();
+ trivial_broadcast_c && shape_iter != end; ++shape_iter, ++stride_iter) {
+ if (expect_stride == *stride_iter)
+ expect_stride *= *shape_iter;
+ else
+ trivial_broadcast_c = false;
+ }
+ }
+
+ // Check for Fortran contiguity (if previous inputs were also F contiguous)
+ if (trivial_broadcast_f) {
+ size_t expect_stride = buffers[i].itemsize;
+ auto end = buffers[i].shape.cend();
+ for (auto shape_iter = buffers[i].shape.cbegin(), stride_iter = buffers[i].strides.cbegin();
+ trivial_broadcast_f && shape_iter != end; ++shape_iter, ++stride_iter) {
+ if (expect_stride == *stride_iter)
+ expect_stride *= *shape_iter;
+ else
+ trivial_broadcast_f = false;
+ }
}
- trivial_broadcast = trivial_broadcast && i_trivial_broadcast;
}
- return trivial_broadcast;
+
+ return
+ trivial_broadcast_c ? broadcast_trivial::c_trivial :
+ trivial_broadcast_f ? broadcast_trivial::f_trivial :
+ broadcast_trivial::non_trivial;
}
template <typename Func, typename Return, typename... Args>
struct vectorize_helper {
typename std::remove_reference<Func>::type f;
+ static constexpr size_t N = sizeof...(Args);
template <typename T>
explicit vectorize_helper(T&&f) : f(std::forward<T>(f)) { }
- object operator()(array_t<Args, array::c_style | array::forcecast>... args) {
- return run(args..., make_index_sequence<sizeof...(Args)>());
+ object operator()(array_t<Args, array::forcecast>... args) {
+ return run(args..., make_index_sequence<N>());
}
- template <size_t ... Index> object run(array_t<Args, array::c_style | array::forcecast>&... args, index_sequence<Index...> index) {
+ template <size_t ... Index> object run(array_t<Args, array::forcecast>&... args, index_sequence<Index...> index) {
/* Request buffers from all parameters */
- const size_t N = sizeof...(Args);
-
std::array<buffer_info, N> buffers {{ args.request()... }};
/* Determine dimensions parameters of output array */
size_t ndim = 0;
std::vector<size_t> shape(0);
- bool trivial_broadcast = broadcast(buffers, ndim, shape);
+ auto trivial = broadcast(buffers, ndim, shape);
size_t size = 1;
std::vector<size_t> strides(ndim);
if (ndim > 0) {
- strides[ndim-1] = sizeof(Return);
- for (size_t i = ndim - 1; i > 0; --i) {
- strides[i - 1] = strides[i] * shape[i];
- size *= shape[i];
+ if (trivial == broadcast_trivial::f_trivial) {
+ strides[0] = sizeof(Return);
+ for (size_t i = 1; i < ndim; ++i) {
+ strides[i] = strides[i - 1] * shape[i - 1];
+ size *= shape[i - 1];
+ }
+ size *= shape[ndim - 1];
+ }
+ else {
+ strides[ndim-1] = sizeof(Return);
+ for (size_t i = ndim - 1; i > 0; --i) {
+ strides[i - 1] = strides[i] * shape[i];
+ size *= shape[i];
+ }
+ size *= shape[0];
}
- size *= shape[0];
}
if (size == 1)
- return cast(f(*((Args *) buffers[Index].ptr)...));
+ return cast(f(*reinterpret_cast<Args *>(buffers[Index].ptr)...));
array_t<Return> result(shape, strides);
auto buf = result.request();
auto output = (Return *) buf.ptr;
- if (trivial_broadcast) {
- /* Call the function */
- for (size_t i = 0; i < size; ++i) {
- output[i] = f((buffers[Index].size == 1
- ? *((Args *) buffers[Index].ptr)
- : ((Args *) buffers[Index].ptr)[i])...);
- }
+ /* Call the function */
+ if (trivial == broadcast_trivial::non_trivial) {
+ apply_broadcast<Index...>(buffers, buf, index);
} else {
- apply_broadcast<N, Index...>(buffers, buf, index);
+ for (size_t i = 0; i < size; ++i)
+ output[i] = f((reinterpret_cast<Args *>(buffers[Index].ptr)[buffers[Index].size == 1 ? 0 : i])...);
}
return result;
}
- template <size_t N, size_t... Index>
+ template <size_t... Index>
void apply_broadcast(const std::array<buffer_info, N> &buffers,
buffer_info &output, index_sequence<Index...>) {
using input_iterator = multi_array_iterator<N>;
@@ -1140,26 +1361,29 @@ struct vectorize_helper {
};
template <typename T, int Flags> struct handle_type_name<array_t<T, Flags>> {
- static PYBIND11_DESCR name() { return _("numpy.ndarray[") + type_caster<T>::name() + _("]"); }
+ static PYBIND11_DESCR name() {
+ return _("numpy.ndarray[") + npy_format_descriptor<T>::name() + _("]");
+ }
};
NAMESPACE_END(detail)
template <typename Func, typename Return, typename... Args>
-detail::vectorize_helper<Func, Return, Args...> vectorize(const Func &f, Return (*) (Args ...)) {
+detail::vectorize_helper<Func, Return, Args...>
+vectorize(const Func &f, Return (*) (Args ...)) {
return detail::vectorize_helper<Func, Return, Args...>(f);
}
template <typename Return, typename... Args>
-detail::vectorize_helper<Return (*) (Args ...), Return, Args...> vectorize(Return (*f) (Args ...)) {
+detail::vectorize_helper<Return (*) (Args ...), Return, Args...>
+vectorize(Return (*f) (Args ...)) {
return vectorize<Return (*) (Args ...), Return, Args...>(f, f);
}
-template <typename Func>
+template <typename Func, typename FuncType = typename detail::remove_class<decltype(&std::remove_reference<Func>::type::operator())>::type>
auto vectorize(Func &&f) -> decltype(
- vectorize(std::forward<Func>(f), (typename detail::remove_class<decltype(&std::remove_reference<Func>::type::operator())>::type *) nullptr)) {
- return vectorize(std::forward<Func>(f), (typename detail::remove_class<decltype(
- &std::remove_reference<Func>::type::operator())>::type *) nullptr);
+ vectorize(std::forward<Func>(f), (FuncType *) nullptr)) {
+ return vectorize(std::forward<Func>(f), (FuncType *) nullptr);
}
NAMESPACE_END(pybind11)