diff --git a/src/graphql/execution/map_async_iterable.py b/src/graphql/execution/map_async_iterable.py index 84bd3f4a..cc6715b6 100644 --- a/src/graphql/execution/map_async_iterable.py +++ b/src/graphql/execution/map_async_iterable.py @@ -1,30 +1,34 @@ from __future__ import annotations # Python < 3.10 -from asyncio import CancelledError, Event, Task, ensure_future, wait -from concurrent.futures import FIRST_COMPLETED -from inspect import isasyncgen, isawaitable from types import TracebackType -from typing import Any, AsyncIterable, Callable, Optional, Set, Type, Union, cast +from typing import Any, AsyncIterable, Awaitable, Callable, Optional, Type, Union __all__ = ["MapAsyncIterable"] +# The following is a class because its type is checked in the code. +# otherwise, it could be implemented as a simple async generator function + + # noinspection PyAttributeOutsideInit class MapAsyncIterable: """Map an AsyncIterable over a callback function. Given an AsyncIterable and a callback function, return an AsyncIterator which - produces values mapped via calling the callback function. + produces values mapped via calling the callback async function. - When the resulting AsyncIterator is closed, the underlying AsyncIterable will also - be closed. + Similar to an AsyncGenerator, an `aclose()` method is provivde which + will close the underlying AsyncIterable be if it has an `aclose()` method. """ - def __init__(self, iterable: AsyncIterable, callback: Callable) -> None: + def __init__( + self, iterable: AsyncIterable[Any], callback: Callable[[Any], Awaitable[Any]] + ) -> None: self.iterator = iterable.__aiter__() self.callback = callback - self._close_event = Event() + self._ageniter = self._agen() + self.is_closed = False # used by unittests def __aiter__(self) -> MapAsyncIterable: """Get the iterator object.""" @@ -32,41 +36,18 @@ def __aiter__(self) -> MapAsyncIterable: async def __anext__(self) -> Any: """Get the next value of the iterator.""" - if self.is_closed: - if not isasyncgen(self.iterator): - raise StopAsyncIteration - value = await self.iterator.__anext__() - else: - aclose = ensure_future(self._close_event.wait()) - anext = ensure_future(self.iterator.__anext__()) - - try: - pending: Set[Task] = ( - await wait([aclose, anext], return_when=FIRST_COMPLETED) - )[1] - except CancelledError: - # cancel underlying tasks and close - aclose.cancel() - anext.cancel() - await self.aclose() - raise # re-raise the cancellation - - for task in pending: - task.cancel() - - if aclose.done(): - raise StopAsyncIteration - - error = anext.exception() - if error: - raise error + return await self._ageniter.__anext__() - value = anext.result() - - result = self.callback(value) - - return await result if isawaitable(result) else result + async def _agen(self) -> Any: + try: + async for v in self.iterator: + yield await self.callback(v) + finally: + self.is_closed = True + if hasattr(self.iterator, "aclose"): + await self.iterator.aclose() + # This is not a standard method and is only used in unittests. Should be removed. async def athrow( self, type_: Union[BaseException, Type[BaseException]], @@ -74,45 +55,8 @@ async def athrow( traceback: Optional[TracebackType] = None, ) -> None: """Throw an exception into the asynchronous iterator.""" - if self.is_closed: - return - athrow = getattr(self.iterator, "athrow", None) - if athrow: - await athrow(type_, value, traceback) - else: - await self.aclose() - if value is None: - if traceback is None: - raise type_ - value = ( - type_ - if isinstance(value, BaseException) - else cast(Type[BaseException], type_)() - ) - if traceback is not None: - value = value.with_traceback(traceback) - raise value + await self._ageniter.athrow(type_, value, traceback) async def aclose(self) -> None: """Close the iterator.""" - if not self.is_closed: - aclose = getattr(self.iterator, "aclose", None) - if aclose: - try: - await aclose() - except RuntimeError: - pass - self.is_closed = True - - @property - def is_closed(self) -> bool: - """Check whether the iterator is closed.""" - return self._close_event.is_set() - - @is_closed.setter - def is_closed(self, value: bool) -> None: - """Mark the iterator as closed.""" - if value: - self._close_event.set() - else: - self._close_event.clear() + await self._ageniter.aclose() diff --git a/tests/execution/test_map_async_iterable.py b/tests/execution/test_map_async_iterable.py index 6406f7dd..09a0b28d 100644 --- a/tests/execution/test_map_async_iterable.py +++ b/tests/execution/test_map_async_iterable.py @@ -18,6 +18,14 @@ async def anext(iterator): return await iterator.__anext__() +async def map_single(x): + return x + + +async def map_doubles(x): + return x + x + + def describe_map_async_iterable(): @mark.asyncio async def maps_over_async_generator(): @@ -26,7 +34,7 @@ async def source(): yield 2 yield 3 - doubles = MapAsyncIterable(source(), lambda x: x + x) + doubles = MapAsyncIterable(source(), map_doubles) assert await anext(doubles) == 2 assert await anext(doubles) == 4 @@ -48,7 +56,7 @@ async def __anext__(self): except IndexError: raise StopAsyncIteration - doubles = MapAsyncIterable(Iterable(), lambda x: x + x) + doubles = MapAsyncIterable(Iterable(), map_doubles) values = [value async for value in doubles] @@ -62,7 +70,7 @@ async def source(): yield 2 yield 3 - doubles = MapAsyncIterable(source(), lambda x: x + x) + doubles = MapAsyncIterable(source(), map_doubles) values = [value async for value in doubles] @@ -91,7 +99,7 @@ async def source(): yield 2 yield 3 # pragma: no cover - doubles = MapAsyncIterable(source(), lambda x: x + x) + doubles = MapAsyncIterable(source(), map_doubles) assert await anext(doubles) == 2 assert await anext(doubles) == 4 @@ -119,7 +127,7 @@ async def __anext__(self): except IndexError: # pragma: no cover raise StopAsyncIteration - doubles = MapAsyncIterable(Iterable(), lambda x: x + x) + doubles = MapAsyncIterable(Iterable(), map_doubles) assert await anext(doubles) == 2 assert await anext(doubles) == 4 @@ -133,8 +141,9 @@ async def __anext__(self): with raises(StopAsyncIteration): await anext(doubles) + # async iterators must not yield after aclose() is called @mark.asyncio - async def passes_through_early_return_from_async_values(): + async def ignored_generator_exit(): async def source(): try: yield 1 @@ -142,20 +151,16 @@ async def source(): yield 3 # pragma: no cover finally: yield "Done" - yield "Last" + yield "Last" # pragma: no cover - doubles = MapAsyncIterable(source(), lambda x: x + x) + doubles = MapAsyncIterable(source(), map_doubles) assert await anext(doubles) == 2 assert await anext(doubles) == 4 - # Early return - await doubles.aclose() - - # Subsequent next calls may yield from finally block - assert await anext(doubles) == "LastLast" - with raises(GeneratorExit): - assert await anext(doubles) + with raises(RuntimeError) as exc_info: + await doubles.aclose() + assert str(exc_info.value) == "async generator ignored GeneratorExit" @mark.asyncio async def allows_throwing_errors_through_async_iterable(): @@ -171,7 +176,7 @@ async def __anext__(self): except IndexError: # pragma: no cover raise StopAsyncIteration - doubles = MapAsyncIterable(Iterable(), lambda x: x + x) + doubles = MapAsyncIterable(Iterable(), map_doubles) assert await anext(doubles) == 2 assert await anext(doubles) == 4 @@ -197,7 +202,7 @@ def __aiter__(self): async def __anext__(self): return 1 - one = MapAsyncIterable(Iterable(), lambda x: x) + one = MapAsyncIterable(Iterable(), map_single) assert await anext(one) == 1 @@ -223,7 +228,7 @@ def __aiter__(self): async def __anext__(self): return 1 - one = MapAsyncIterable(Iterable(), lambda x: x) + one = MapAsyncIterable(Iterable(), map_single) assert await anext(one) == 1 @@ -250,18 +255,14 @@ async def source(): except Exception as e: yield e - doubles = MapAsyncIterable(source(), lambda x: x + x) + doubles = MapAsyncIterable(source(), map_doubles) assert await anext(doubles) == 2 assert await anext(doubles) == 4 # Throw error - await doubles.athrow(RuntimeError("ouch")) - - with raises(StopAsyncIteration): - await anext(doubles) - with raises(StopAsyncIteration): - await anext(doubles) + with raises(RuntimeError): + await doubles.athrow(RuntimeError("ouch")) @mark.asyncio async def does_not_normally_map_over_thrown_errors(): @@ -269,7 +270,7 @@ async def source(): yield "Hello" raise RuntimeError("Goodbye") - doubles = MapAsyncIterable(source(), lambda x: x + x) + doubles = MapAsyncIterable(source(), map_doubles) assert await anext(doubles) == "HelloHello" @@ -283,7 +284,7 @@ async def does_not_normally_map_over_externally_thrown_errors(): async def source(): yield "Hello" - doubles = MapAsyncIterable(source(), lambda x: x + x) + doubles = MapAsyncIterable(source(), map_doubles) assert await anext(doubles) == "HelloHello" @@ -312,7 +313,7 @@ async def __anext__(self): raise StopAsyncIteration return self.counter - def double(x): + async def double(x): return x + x for iterable in source, Source: @@ -384,7 +385,11 @@ async def source(): yield 3 # pragma: no cover singles = source() - doubles = MapAsyncIterable(singles, lambda x: x * 2) + + async def double(x): + return x * 2 + + doubles = MapAsyncIterable(singles, double) result = await anext(doubles) assert result == 2 @@ -394,65 +399,28 @@ async def source(): await sleep(0.05) assert not doubles_future.done() - # Unblock and watch StopAsyncIteration propagate - await doubles.aclose() - await sleep(0.05) - assert doubles_future.done() - assert isinstance(doubles_future.exception(), StopAsyncIteration) + # with python 3.8 and higher, close() cannot be used to unblock a generator. + # instead, the task should be killed. AsyncGenerators are not re-entrant. + if sys.version_info[:2] >= (3, 8): + with raises(RuntimeError): + await doubles.aclose() + doubles_future.cancel() + await sleep(0.05) + assert doubles_future.done() + with raises(CancelledError): + doubles_future.exception() + + else: + # old behaviour, where aclose() could unblock a Task + # Unblock and watch StopAsyncIteration propagate + await doubles.aclose() + await sleep(0.05) + assert doubles_future.done() + assert isinstance(doubles_future.exception(), StopAsyncIteration) with raises(StopAsyncIteration): await anext(singles) - @mark.asyncio - async def can_unset_closed_state_of_async_iterable(): - items = [1, 2, 3] - - class Iterable: - def __init__(self): - self.is_closed = False - - def __aiter__(self): - return self - - async def __anext__(self): - if self.is_closed: - raise StopAsyncIteration - try: - return items.pop(0) - except IndexError: - raise StopAsyncIteration - - async def aclose(self): - self.is_closed = True - - iterable = Iterable() - doubles = MapAsyncIterable(iterable, lambda x: x + x) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 - assert not iterable.is_closed - await doubles.aclose() - assert iterable.is_closed - with raises(StopAsyncIteration): - await anext(iterable) - with raises(StopAsyncIteration): - await anext(doubles) - assert doubles.is_closed - - iterable.is_closed = False - doubles.is_closed = False - assert not doubles.is_closed - - assert await anext(doubles) == 6 - assert not doubles.is_closed - assert not iterable.is_closed - with raises(StopAsyncIteration): - await anext(iterable) - with raises(StopAsyncIteration): - await anext(doubles) - assert not doubles.is_closed - assert not iterable.is_closed - @mark.asyncio async def can_cancel_async_iterable_while_waiting(): class Iterable: @@ -475,7 +443,7 @@ async def aclose(self): self.is_closed = True iterable = Iterable() - doubles = MapAsyncIterable(iterable, lambda x: x + x) # pragma: no cover exit + doubles = MapAsyncIterable(iterable, map_doubles) # pragma: no cover exit cancelled = False async def iterator_task():