summaryrefslogtreecommitdiff
path: root/ext/pybind11/tests/test_numpy_array.py
blob: 1c218a10b8e2e47e2d6e4bb749001ed4460bc4c1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
import pytest
import gc

with pytest.suppress(ImportError):
    import numpy as np


@pytest.fixture(scope='function')
def arr():
    """Return a newly built 2x3 array of little-endian uint16 holding 1..6 (function scope)."""
    rows = [[1, 2, 3], [4, 5, 6]]
    return np.array(rows, '<u2')


@pytest.requires_numpy
def test_array_attributes():
    """Check the array attribute accessors on a 0-d array and on a read-only 2-d view."""
    from pybind11_tests.array import (
        ndim, shape, strides, writeable, size, itemsize, nbytes, owndata
    )

    def expect_axis_error(array, axis, message):
        # Both per-axis accessors must reject an out-of-range axis with the same message.
        for accessor in (shape, strides):
            with pytest.raises(IndexError) as excinfo:
                accessor(array, axis)
            assert str(excinfo.value) == message

    # Zero-dimensional float64 array: no axes at all, one 8-byte element.
    scalar = np.array(0, 'f8')
    assert ndim(scalar) == 0
    assert all(shape(scalar) == [])
    assert all(strides(scalar) == [])
    expect_axis_error(scalar, 0, 'invalid axis: 0 (ndim = 0)')
    assert writeable(scalar)
    assert size(scalar) == 1
    assert itemsize(scalar) == 8
    assert nbytes(scalar) == 8
    assert owndata(scalar)

    # Read-only view of a 2x3 uint16 array: the view does not own its data.
    base = np.array([[1, 2, 3], [4, 5, 6]], 'u2')
    view = base.view()
    view.flags.writeable = False
    assert ndim(view) == 2
    assert all(shape(view) == [2, 3])
    for axis, extent in enumerate((2, 3)):
        assert shape(view, axis) == extent
    assert all(strides(view) == [6, 2])
    for axis, step in enumerate((6, 2)):
        assert strides(view, axis) == step
    expect_axis_error(view, 2, 'invalid axis: 2 (ndim = 2)')
    assert not writeable(view)
    assert size(view) == 6
    assert itemsize(view) == 2
    assert nbytes(view) == 12
    assert not owndata(view)


@pytest.requires_numpy
@pytest.mark.parametrize('args, ret', [
    ([], 0),
    ([0], 0),
    ([1], 3),
    ([0, 1], 1),
    ([1, 2], 5),
])
def test_index_offset(arr, args, ret):
    """Check element indices and the matching byte offsets for partial and full indices.

    ``ret`` is the expected flat element index; the byte offset is that index
    multiplied by the item size of ``arr``.
    """
    from pybind11_tests.array import index_at, index_at_t, offset_at, offset_at_t

    for index_func in (index_at, index_at_t):
        assert index_func(arr, *args) == ret

    expected_bytes = ret * arr.dtype.itemsize
    for offset_func in (offset_at, offset_at_t):
        assert offset_func(arr, *args) == expected_bytes


@pytest.requires_numpy
def test_dim_check_fail(arr):
    """Check that passing three indices for the 2-d ``arr`` raises ``IndexError`` everywhere."""
    from pybind11_tests.array import (index_at, index_at_t, offset_at, offset_at_t, data, data_t,
                                      mutate_data, mutate_data_t)

    expected_message = 'too many indices for an array: 3 (ndim = 2)'
    indexed_funcs = (
        index_at, index_at_t,
        offset_at, offset_at_t,
        data, data_t,
        mutate_data, mutate_data_t,
    )
    for indexed in indexed_funcs:
        with pytest.raises(IndexError) as excinfo:
            indexed(arr, 1, 2, 3)
        assert str(excinfo.value) == expected_message


@pytest.requires_numpy
@pytest.mark.parametrize('args, ret',
                         [([], [1, 2, 3, 4, 5, 6]),
                          ([1], [4, 5, 6]),
                          ([0, 1], [2, 3, 4, 5, 6]),
                          ([1, 2], [6])])
def test_data(arr, args, ret):
    """Check the data returned from a given index onward by ``data_t`` and ``data``.

    ``data_t`` is compared directly against the expected elements, while for
    ``data`` every even-positioned entry must match an expected element and
    every odd-positioned entry must be zero.
    """
    from pybind11_tests.array import data, data_t

    typed_values = data_t(arr, *args)
    assert all(typed_values == ret)

    even_entries = data(arr, *args)[::2]
    assert all(even_entries == ret)

    odd_entries = data(arr, *args)[1::2]
    assert all(odd_entries == 0)


@pytest.requires_numpy
def test_mutate_readonly(arr):
    """Check that every mutating accessor refuses an array flagged as non-writeable."""
    from pybind11_tests.array import mutate_data, mutate_data_t, mutate_at_t

    arr.flags.writeable = False

    # Each entry pairs a mutator with the index arguments it is called with.
    attempts = [
        (mutate_data, ()),
        (mutate_data_t, ()),
        (mutate_at_t, (0, 0)),
    ]
    for mutator, index in attempts:
        with pytest.raises(RuntimeError) as excinfo:
            mutator(arr, *index)
        assert str(excinfo.value) == 'array is not writeable'


@pytest.requires_numpy
@pytest.mark.parametrize('dim', [0, 1, 3])
def test_at_fail(arr, dim):
    """Check that ``at_t``/``mutate_at_t`` reject an index count different from ``arr.ndim``."""
    from pybind11_tests.array import at_t, mutate_at_t

    zero_index = [0] * dim
    expected_message = 'index dimension mismatch: {} (ndim = 2)'.format(dim)
    for accessor in (at_t, mutate_at_t):
        with pytest.raises(IndexError) as excinfo:
            accessor(arr, *zero_index)
        assert str(excinfo.value) == expected_message


@pytest.requires_numpy
def test_at(arr):
    """Check element reads via ``at_t`` and the array returned by ``mutate_at_t``."""
    from pybind11_tests.array import at_t, mutate_at_t

    # Plain reads: (row, column) -> expected element.
    reads = [((0, 2), 3), ((1, 0), 4)]
    for (row, col), value in reads:
        assert at_t(arr, row, col) == value

    # Successive mutations of the same array; the expected values show the
    # addressed element growing by one each time.
    writes = [
        ((0, 2), [1, 2, 4, 4, 5, 6]),
        ((1, 0), [1, 2, 4, 5, 5, 6]),
    ]
    for (row, col), flattened in writes:
        assert all(mutate_at_t(arr, row, col).ravel() == flattened)


@pytest.requires_numpy
def test_mutate_data(arr):
    """Check the cumulative effect of ``mutate_data`` and ``mutate_data_t`` on one array.

    The steps run in order against the same ``arr``. Per the expected values,
    ``mutate_data`` doubles and ``mutate_data_t`` increments every element from
    the given index to the end of the array.
    """
    from pybind11_tests.array import mutate_data, mutate_data_t

    # (mutator, index arguments, expected flattened contents afterwards)
    steps = [
        (mutate_data, (), [2, 4, 6, 8, 10, 12]),
        (mutate_data, (), [4, 8, 12, 16, 20, 24]),
        (mutate_data, (1,), [4, 8, 12, 32, 40, 48]),
        (mutate_data, (0, 1), [4, 16, 24, 64, 80, 96]),
        (mutate_data, (1, 2), [4, 16, 24, 64, 80, 192]),
        (mutate_data_t, (), [5, 17, 25, 65, 81, 193]),
        (mutate_data_t, (), [6, 18, 26, 66, 82, 194]),
        (mutate_data_t, (1,), [6, 18, 26, 67, 83, 195]),
        (mutate_data_t, (0, 1), [6, 19, 27, 68, 84, 196]),
        (mutate_data_t, (1, 2), [6, 19, 27, 68, 84, 197]),
    ]
    for mutator, index, flattened in steps:
        assert all(mutator(arr, *index).ravel() == flattened)


@pytest.requires_numpy
def test_bounds_check(arr):
    """Check that an out-of-range index on either axis raises ``IndexError`` with a clear message."""
    from pybind11_tests.array import (index_at, index_at_t, data, data_t,
                                      mutate_data, mutate_data_t, at_t, mutate_at_t)

    # ``arr`` is 2x3, so row 2 and column 4 are both out of range.
    out_of_range = (
        ((2, 0), 'index 2 is out of bounds for axis 0 with size 2'),
        ((0, 4), 'index 4 is out of bounds for axis 1 with size 3'),
    )
    checked_funcs = (index_at, index_at_t, data, data_t,
                     mutate_data, mutate_data_t, at_t, mutate_at_t)
    for checked in checked_funcs:
        for (row, col), message in out_of_range:
            with pytest.raises(IndexError) as excinfo:
                checked(arr, row, col)
            assert str(excinfo.value) == message


@pytest.requires_numpy
def test_make_c_f_array():
    """Check that each factory yields an array contiguous in its own order only."""
    from pybind11_tests.array import (
        make_c_array, make_f_array
    )

    # (factory, flag that must be set, flag that must be clear)
    layouts = (
        (make_c_array, 'c_contiguous', 'f_contiguous'),
        (make_f_array, 'f_contiguous', 'c_contiguous'),
    )
    for factory, own_order, other_order in layouts:
        assert getattr(factory().flags, own_order)
        assert not getattr(factory().flags, other_order)


@pytest.requires_numpy
def test_wrap():
    """Check that ``wrap`` returns a distinct array object that references its input.

    The wrapped array must share the input's data pointer, shape, strides and
    flags, must not own its data, and must report the input as its ``base``.
    Several layouts are covered: 1-d, Fortran-ordered, read-only C-ordered,
    3-d, transposed and diagonal views.
    """
    from pybind11_tests.array import wrap

    def assert_references(a, b):
        """Assert that ``b`` is a non-owning array sharing ``a``'s buffer and metadata."""
        assert a is not b
        assert a.__array_interface__['data'][0] == b.__array_interface__['data'][0]
        assert a.shape == b.shape
        assert a.strides == b.strides
        assert a.flags.c_contiguous == b.flags.c_contiguous
        assert a.flags.f_contiguous == b.flags.f_contiguous
        assert a.flags.writeable == b.flags.writeable
        assert a.flags.aligned == b.flags.aligned
        # NumPy 1.14 deprecated ``flags.updateifcopy`` in favour of ``writebackifcopy``
        # and later releases removed it; compare whichever flag this NumPy provides
        # so the check does not fail with AttributeError on current versions.
        if hasattr(a.flags, 'writebackifcopy'):
            copy_flag = 'writebackifcopy'
        else:
            copy_flag = 'updateifcopy'
        assert getattr(a.flags, copy_flag) == getattr(b.flags, copy_flag)
        assert np.all(a == b)
        assert not b.flags.owndata
        assert b.base is a
        if a.flags.writeable and a.ndim == 2:
            # A write through the original must be visible through the wrapper.
            a[0, 0] = 1234
            assert b[0, 0] == 1234

    # 1-d array that owns its data.
    a1 = np.array([1, 2], dtype=np.int16)
    assert a1.flags.owndata and a1.base is None
    a2 = wrap(a1)
    assert_references(a1, a2)

    # 2-d Fortran-ordered array that owns its data.
    a1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='F')
    assert a1.flags.owndata and a1.base is None
    a2 = wrap(a1)
    assert_references(a1, a2)

    # 2-d C-ordered array marked read-only.
    a1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='C')
    a1.flags.writeable = False
    a2 = wrap(a1)
    assert_references(a1, a2)

    # 3-d array of random values.
    a1 = np.random.random((4, 4, 4))
    a2 = wrap(a1)
    assert_references(a1, a2)

    # Transposed view of the 3-d array.
    a1 = a1.transpose()
    a2 = wrap(a1)
    assert_references(a1, a2)

    # Diagonal of the transposed view.
    a1 = a1.diagonal()
    a2 = wrap(a1)
    assert_references(a1, a2)


@pytest.requires_numpy
def test_numpy_view(capture):
    """Check that NumPy views keep their ``ArrayClass`` owner alive and share its data.

    Per the captured output, dropping the Python reference to the owner does
    not destroy it while views exist; the destructor message appears only
    once both views have been deleted.
    """
    from pybind11_tests.array import ArrayClass

    with capture:
        owner = ArrayClass()
        first_view = owner.numpy_view()
        second_view = owner.numpy_view()
        initial_values = np.array([1, 2], dtype=np.int32)
        assert np.all(first_view == initial_values)
        del owner
        gc.collect()
    assert capture == """
        ArrayClass()
        ArrayClass::numpy_view()
        ArrayClass::numpy_view()
    """

    # Writes through one view must be observable through the other.
    first_view[0], first_view[1] = 4, 3
    assert second_view[0] == 4
    assert second_view[1] == 3

    with capture:
        del first_view, second_view
        gc.collect()
    assert capture == """
        ~ArrayClass()
    """


@pytest.requires_numpy
def test_cast_numpy_int64_to_uint64():
    """Check that ``function_taking_uint64`` accepts a Python int and a ``np.uint64`` alike.

    The test passes as long as neither call raises.
    """
    from pybind11_tests.array import function_taking_uint64

    for value in (123, np.uint64(123)):
        function_taking_uint64(value)


@pytest.requires_numpy
def test_isinstance():
    """Check the untyped and typed array ``isinstance`` helpers exposed by the test module."""
    from pybind11_tests.array import isinstance_untyped, isinstance_typed

    # The untyped helper takes an array together with a non-array object.
    int_array = np.array([1, 2, 3])
    assert isinstance_untyped(int_array, "not an array")

    float_array = np.array([1.0, 2.0, 3.0])
    assert isinstance_typed(float_array)


@pytest.requires_numpy
def test_constructors():
    """Check the arrays produced by the default and converting constructors.

    Default-constructed arrays must be empty; arrays converted from
    ``[1, 2, 3]`` must hold those values. In both cases the typed variants
    must carry their declared dtype.
    """
    from pybind11_tests.array import default_constructors, converting_constructors

    # Result keys of the typed variants and the dtype each must report.
    typed_dtypes = (
        ("array_t<int32>", np.int32),
        ("array_t<double>", np.float64),
    )

    empties = default_constructors()
    for made in empties.values():
        assert made.size == 0
    # The untyped default must match NumPy's own dtype for an empty array.
    assert empties["array"].dtype == np.array([]).dtype
    for key, dtype in typed_dtypes:
        assert empties[key].dtype == dtype

    source = [1, 2, 3]
    converted = converting_constructors(source)
    for made in converted.values():
        np.testing.assert_array_equal(made, [1, 2, 3])
    assert converted["array"].dtype == np.int_
    for key, dtype in typed_dtypes:
        assert converted[key].dtype == dtype