summaryrefslogtreecommitdiff
path: root/ext/pybind11/tests/test_sequences_and_iterators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ext/pybind11/tests/test_sequences_and_iterators.py')
-rw-r--r--ext/pybind11/tests/test_sequences_and_iterators.py41
1 files changed, 38 insertions, 3 deletions
diff --git a/ext/pybind11/tests/test_sequences_and_iterators.py b/ext/pybind11/tests/test_sequences_and_iterators.py
index 76b9f43f6..30b6aaf4b 100644
--- a/ext/pybind11/tests/test_sequences_and_iterators.py
+++ b/ext/pybind11/tests/test_sequences_and_iterators.py
@@ -11,7 +11,7 @@ def allclose(a_list, b_list, rel_tol=1e-05, abs_tol=0.0):
def test_generalized_iterators():
- from pybind11_tests import IntPairs
+ from pybind11_tests.sequences_and_iterators import IntPairs
assert list(IntPairs([(1, 2), (3, 4), (0, 5)]).nonzero()) == [(1, 2), (3, 4)]
assert list(IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero()) == [(1, 2)]
@@ -23,7 +23,8 @@ def test_generalized_iterators():
def test_sequence():
- from pybind11_tests import Sequence, ConstructorStats
+ from pybind11_tests import ConstructorStats
+ from pybind11_tests.sequences_and_iterators import Sequence
cstats = ConstructorStats.get(Sequence)
@@ -71,7 +72,7 @@ def test_sequence():
def test_map_iterator():
- from pybind11_tests import StringMap
+ from pybind11_tests.sequences_and_iterators import StringMap
m = StringMap({'hi': 'bye', 'black': 'white'})
assert m['hi'] == 'bye'
@@ -88,3 +89,37 @@ def test_map_iterator():
assert m[k] == expected[k]
for k, v in m.items():
assert v == expected[k]
+
+
+def test_python_iterator_in_cpp():
+ import pybind11_tests.sequences_and_iterators as m
+
+ t = (1, 2, 3)
+ assert m.object_to_list(t) == [1, 2, 3]
+ assert m.object_to_list(iter(t)) == [1, 2, 3]
+ assert m.iterator_to_list(iter(t)) == [1, 2, 3]
+
+ with pytest.raises(TypeError) as excinfo:
+ m.object_to_list(1)
+ assert "object is not iterable" in str(excinfo.value)
+
+ with pytest.raises(TypeError) as excinfo:
+ m.iterator_to_list(1)
+ assert "incompatible function arguments" in str(excinfo.value)
+
+ def bad_next_call():
+ raise RuntimeError("py::iterator::advance() should propagate errors")
+
+ with pytest.raises(RuntimeError) as excinfo:
+ m.iterator_to_list(iter(bad_next_call, None))
+ assert str(excinfo.value) == "py::iterator::advance() should propagate errors"
+
+ l = [1, None, 0, None]
+ assert m.count_none(l) == 2
+ assert m.find_none(l) is True
+ assert m.count_nonzeros({"a": 0, "b": 1, "c": 2}) == 2
+
+ r = range(5)
+ assert all(m.tuple_iterator(tuple(r)))
+ assert all(m.list_iterator(list(r)))
+ assert all(m.sequence_iterator(r))