Skip to content

Commit daf5394

Browse files
committed
Rewrite MapAsyncIterable using async generator semantics
1 parent a9b9568 commit daf5394

File tree

2 files changed

+48
-149
lines changed

2 files changed

+48
-149
lines changed
Lines changed: 22 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
from __future__ import annotations # Python < 3.10
22

3-
from asyncio import CancelledError, Event, Task, ensure_future, wait
4-
from concurrent.futures import FIRST_COMPLETED
5-
from inspect import isasyncgen, isawaitable
3+
from inspect import isawaitable
64
from types import TracebackType
7-
from typing import Any, AsyncIterable, Callable, Optional, Set, Type, Union, cast
5+
from typing import Any, AsyncIterable, Callable, Optional, Type, Union
86

97

108
__all__ = ["MapAsyncIterable"]
119

1210

11+
# The following is a class because its type is checked in the code.
12+
# otherwise, it could be implemented as a simple async generator function
13+
1314
# noinspection PyAttributeOutsideInit
1415
class MapAsyncIterable:
1516
"""Map an AsyncIterable over a callback function.
@@ -22,97 +23,39 @@ class MapAsyncIterable:
2223
"""
2324

2425
def __init__(self, iterable: AsyncIterable, callback: Callable) -> None:
25-
self.iterator = iterable.__aiter__()
26+
self.iterable = iterable
2627
self.callback = callback
27-
self._close_event = Event()
28+
self._ageniter = self._agen()
29+
self.is_closed = False # used by unittests
2830

2931
def __aiter__(self) -> MapAsyncIterable:
3032
"""Get the iterator object."""
3133
return self
3234

3335
async def __anext__(self) -> Any:
3436
"""Get the next value of the iterator."""
35-
if self.is_closed:
36-
if not isasyncgen(self.iterator):
37-
raise StopAsyncIteration
38-
value = await self.iterator.__anext__()
39-
else:
40-
aclose = ensure_future(self._close_event.wait())
41-
anext = ensure_future(self.iterator.__anext__())
42-
43-
try:
44-
pending: Set[Task] = (
45-
await wait([aclose, anext], return_when=FIRST_COMPLETED)
46-
)[1]
47-
except CancelledError:
48-
# cancel underlying tasks and close
49-
aclose.cancel()
50-
anext.cancel()
51-
await self.aclose()
52-
raise # re-raise the cancellation
53-
54-
for task in pending:
55-
task.cancel()
56-
57-
if aclose.done():
58-
raise StopAsyncIteration
59-
60-
error = anext.exception()
61-
if error:
62-
raise error
63-
64-
value = anext.result()
65-
66-
result = self.callback(value)
67-
68-
return await result if isawaitable(result) else result
37+
return await self._ageniter.__anext__()
38+
39+
async def _agen(self) -> Any:
40+
try:
41+
async for v in self.iterable:
42+
result = self.callback(v)
43+
yield (await result) if isawaitable(result) else result
44+
finally:
45+
self.is_closed = True
46+
if hasattr(self.iterable, "aclose"):
47+
await self.iterable.aclose()
6948

49+
# This is not a standard method and is only used in unittests. Should be removed.
7050
async def athrow(
7151
self,
7252
type_: Union[BaseException, Type[BaseException]],
7353
value: Optional[BaseException] = None,
7454
traceback: Optional[TracebackType] = None,
7555
) -> None:
7656
"""Throw an exception into the asynchronous iterator."""
77-
if self.is_closed:
78-
return
79-
athrow = getattr(self.iterator, "athrow", None)
80-
if athrow:
81-
await athrow(type_, value, traceback)
82-
else:
83-
await self.aclose()
84-
if value is None:
85-
if traceback is None:
86-
raise type_
87-
value = (
88-
type_
89-
if isinstance(value, BaseException)
90-
else cast(Type[BaseException], type_)()
91-
)
92-
if traceback is not None:
93-
value = value.with_traceback(traceback)
94-
raise value
57+
await self._ageniter.athrow(type_, value, traceback)
9558

9659
async def aclose(self) -> None:
9760
"""Close the iterator."""
98-
if not self.is_closed:
99-
aclose = getattr(self.iterator, "aclose", None)
100-
if aclose:
101-
try:
102-
await aclose()
103-
except RuntimeError:
104-
pass
105-
self.is_closed = True
106-
107-
@property
108-
def is_closed(self) -> bool:
109-
"""Check whether the iterator is closed."""
110-
return self._close_event.is_set()
111-
112-
@is_closed.setter
113-
def is_closed(self, value: bool) -> None:
114-
"""Mark the iterator as closed."""
115-
if value:
116-
self._close_event.set()
117-
else:
118-
self._close_event.clear()
61+
await self._ageniter.aclose()

tests/execution/test_map_async_iterable.py

Lines changed: 26 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -133,29 +133,26 @@ async def __anext__(self):
133133
with raises(StopAsyncIteration):
134134
await anext(doubles)
135135

136+
# async iterators must not yield after aclose() is called
136137
@mark.asyncio
137-
async def passes_through_early_return_from_async_values():
138+
async def ignored_generator_exit():
138139
async def source():
139140
try:
140141
yield 1
141142
yield 2
142143
yield 3 # pragma: no cover
143144
finally:
144145
yield "Done"
145-
yield "Last"
146+
yield "Last" # pragma: no cover
146147

147148
doubles = MapAsyncIterable(source(), lambda x: x + x)
148149

149150
assert await anext(doubles) == 2
150151
assert await anext(doubles) == 4
151152

152-
# Early return
153-
await doubles.aclose()
154-
155-
# Subsequent next calls may yield from finally block
156-
assert await anext(doubles) == "LastLast"
157-
with raises(GeneratorExit):
158-
assert await anext(doubles)
153+
with raises(RuntimeError) as exc_info:
154+
await doubles.aclose()
155+
assert str(exc_info.value) == "async generator ignored GeneratorExit"
159156

160157
@mark.asyncio
161158
async def allows_throwing_errors_through_async_iterable():
@@ -256,12 +253,8 @@ async def source():
256253
assert await anext(doubles) == 4
257254

258255
# Throw error
259-
await doubles.athrow(RuntimeError("ouch"))
260-
261-
with raises(StopAsyncIteration):
262-
await anext(doubles)
263-
with raises(StopAsyncIteration):
264-
await anext(doubles)
256+
with raises(RuntimeError):
257+
await doubles.athrow(RuntimeError("ouch"))
265258

266259
@mark.asyncio
267260
async def does_not_normally_map_over_thrown_errors():
@@ -394,65 +387,28 @@ async def source():
394387
await sleep(0.05)
395388
assert not doubles_future.done()
396389

397-
# Unblock and watch StopAsyncIteration propagate
398-
await doubles.aclose()
399-
await sleep(0.05)
400-
assert doubles_future.done()
401-
assert isinstance(doubles_future.exception(), StopAsyncIteration)
390+
# with python 3.8 and higher, close() cannot be used to unblock a generator.
391+
# instead, the task should be killed. AsyncGenerators are not re-entrant.
392+
if sys.version_info[:2] >= (3, 8):
393+
with raises(RuntimeError):
394+
await doubles.aclose()
395+
doubles_future.cancel()
396+
await sleep(0.05)
397+
assert doubles_future.done()
398+
with raises(CancelledError):
399+
doubles_future.exception()
400+
401+
else:
402+
# old behaviour, where aclose() could unblock a Task
403+
# Unblock and watch StopAsyncIteration propagate
404+
await doubles.aclose()
405+
await sleep(0.05)
406+
assert doubles_future.done()
407+
assert isinstance(doubles_future.exception(), StopAsyncIteration)
402408

403409
with raises(StopAsyncIteration):
404410
await anext(singles)
405411

406-
@mark.asyncio
407-
async def can_unset_closed_state_of_async_iterable():
408-
items = [1, 2, 3]
409-
410-
class Iterable:
411-
def __init__(self):
412-
self.is_closed = False
413-
414-
def __aiter__(self):
415-
return self
416-
417-
async def __anext__(self):
418-
if self.is_closed:
419-
raise StopAsyncIteration
420-
try:
421-
return items.pop(0)
422-
except IndexError:
423-
raise StopAsyncIteration
424-
425-
async def aclose(self):
426-
self.is_closed = True
427-
428-
iterable = Iterable()
429-
doubles = MapAsyncIterable(iterable, lambda x: x + x)
430-
431-
assert await anext(doubles) == 2
432-
assert await anext(doubles) == 4
433-
assert not iterable.is_closed
434-
await doubles.aclose()
435-
assert iterable.is_closed
436-
with raises(StopAsyncIteration):
437-
await anext(iterable)
438-
with raises(StopAsyncIteration):
439-
await anext(doubles)
440-
assert doubles.is_closed
441-
442-
iterable.is_closed = False
443-
doubles.is_closed = False
444-
assert not doubles.is_closed
445-
446-
assert await anext(doubles) == 6
447-
assert not doubles.is_closed
448-
assert not iterable.is_closed
449-
with raises(StopAsyncIteration):
450-
await anext(iterable)
451-
with raises(StopAsyncIteration):
452-
await anext(doubles)
453-
assert not doubles.is_closed
454-
assert not iterable.is_closed
455-
456412
@mark.asyncio
457413
async def can_cancel_async_iterable_while_waiting():
458414
class Iterable:

0 commit comments

Comments
 (0)