From e86f811bd70d0b518361ac4f359615cf303a35ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 21 Apr 2023 14:29:28 +0000 Subject: [PATCH 1/3] Rewrite MapAsyncIterable using async generator semantics --- src/graphql/execution/map_async_iterable.py | 101 +++++--------------- tests/execution/test_map_async_iterable.py | 96 +++++-------------- 2 files changed, 48 insertions(+), 149 deletions(-) diff --git a/src/graphql/execution/map_async_iterable.py b/src/graphql/execution/map_async_iterable.py index 84bd3f4a..885fc1bc 100644 --- a/src/graphql/execution/map_async_iterable.py +++ b/src/graphql/execution/map_async_iterable.py @@ -1,15 +1,16 @@ from __future__ import annotations # Python < 3.10 -from asyncio import CancelledError, Event, Task, ensure_future, wait -from concurrent.futures import FIRST_COMPLETED -from inspect import isasyncgen, isawaitable +from inspect import isawaitable from types import TracebackType -from typing import Any, AsyncIterable, Callable, Optional, Set, Type, Union, cast +from typing import Any, AsyncIterable, Callable, Optional, Type, Union __all__ = ["MapAsyncIterable"] +# The following is a class because its type is checked in the code. +# otherwise, it could be implemented as a simple async generator function + # noinspection PyAttributeOutsideInit class MapAsyncIterable: """Map an AsyncIterable over a callback function. @@ -22,9 +23,10 @@ class MapAsyncIterable: """ def __init__(self, iterable: AsyncIterable, callback: Callable) -> None: - self.iterator = iterable.__aiter__() + self.iterable = iterable self.callback = callback - self._close_event = Event() + self._ageniter = self._agen() + self.is_closed = False # used by unittests def __aiter__(self) -> MapAsyncIterable: """Get the iterator object.""" @@ -32,41 +34,19 @@ def __aiter__(self) -> MapAsyncIterable: async def __anext__(self) -> Any: """Get the next value of the iterator.""" - if self.is_closed: - if not isasyncgen(self.iterator): - raise StopAsyncIteration - value = await self.iterator.__anext__() - else: - aclose = ensure_future(self._close_event.wait()) - anext = ensure_future(self.iterator.__anext__()) - - try: - pending: Set[Task] = ( - await wait([aclose, anext], return_when=FIRST_COMPLETED) - )[1] - except CancelledError: - # cancel underlying tasks and close - aclose.cancel() - anext.cancel() - await self.aclose() - raise # re-raise the cancellation - - for task in pending: - task.cancel() - - if aclose.done(): - raise StopAsyncIteration - - error = anext.exception() - if error: - raise error - - value = anext.result() - - result = self.callback(value) - - return await result if isawaitable(result) else result + return await self._ageniter.__anext__() + + async def _agen(self) -> Any: + try: + async for v in self.iterable: + result = self.callback(v) + yield (await result) if isawaitable(result) else result + finally: + self.is_closed = True + if hasattr(self.iterable, "aclose"): + await self.iterable.aclose() + # This is not a standard method and is only used in unittests. Should be removed. async def athrow( self, type_: Union[BaseException, Type[BaseException]], @@ -74,45 +54,8 @@ async def athrow( traceback: Optional[TracebackType] = None, ) -> None: """Throw an exception into the asynchronous iterator.""" - if self.is_closed: - return - athrow = getattr(self.iterator, "athrow", None) - if athrow: - await athrow(type_, value, traceback) - else: - await self.aclose() - if value is None: - if traceback is None: - raise type_ - value = ( - type_ - if isinstance(value, BaseException) - else cast(Type[BaseException], type_)() - ) - if traceback is not None: - value = value.with_traceback(traceback) - raise value + await self._ageniter.athrow(type_, value, traceback) async def aclose(self) -> None: """Close the iterator.""" - if not self.is_closed: - aclose = getattr(self.iterator, "aclose", None) - if aclose: - try: - await aclose() - except RuntimeError: - pass - self.is_closed = True - - @property - def is_closed(self) -> bool: - """Check whether the iterator is closed.""" - return self._close_event.is_set() - - @is_closed.setter - def is_closed(self, value: bool) -> None: - """Mark the iterator as closed.""" - if value: - self._close_event.set() - else: - self._close_event.clear() + await self._ageniter.aclose() diff --git a/tests/execution/test_map_async_iterable.py b/tests/execution/test_map_async_iterable.py index 6406f7dd..aab1e7fd 100644 --- a/tests/execution/test_map_async_iterable.py +++ b/tests/execution/test_map_async_iterable.py @@ -133,8 +133,9 @@ async def __anext__(self): with raises(StopAsyncIteration): await anext(doubles) + # async iterators must not yield after aclose() is called @mark.asyncio - async def passes_through_early_return_from_async_values(): + async def ignored_generator_exit(): async def source(): try: yield 1 @@ -142,20 +143,16 @@ async def source(): yield 3 # pragma: no cover finally: yield "Done" - yield "Last" + yield "Last" # pragma: no cover doubles = MapAsyncIterable(source(), lambda x: x + x) assert await anext(doubles) == 2 assert await anext(doubles) == 4 - # Early return - await doubles.aclose() - - # Subsequent next calls may yield from finally block - assert await anext(doubles) == "LastLast" - with raises(GeneratorExit): - assert await anext(doubles) + with raises(RuntimeError) as exc_info: + await doubles.aclose() + assert str(exc_info.value) == "async generator ignored GeneratorExit" @mark.asyncio async def allows_throwing_errors_through_async_iterable(): @@ -256,12 +253,8 @@ async def source(): assert await anext(doubles) == 4 # Throw error - await doubles.athrow(RuntimeError("ouch")) - - with raises(StopAsyncIteration): - await anext(doubles) - with raises(StopAsyncIteration): - await anext(doubles) + with raises(RuntimeError): + await doubles.athrow(RuntimeError("ouch")) @mark.asyncio async def does_not_normally_map_over_thrown_errors(): @@ -394,65 +387,28 @@ async def source(): await sleep(0.05) assert not doubles_future.done() - # Unblock and watch StopAsyncIteration propagate - await doubles.aclose() - await sleep(0.05) - assert doubles_future.done() - assert isinstance(doubles_future.exception(), StopAsyncIteration) + # with python 3.8 and higher, close() cannot be used to unblock a generator. + # instead, the task should be killed. AsyncGenerators are not re-entrant. + if sys.version_info[:2] >= (3, 8): + with raises(RuntimeError): + await doubles.aclose() + doubles_future.cancel() + await sleep(0.05) + assert doubles_future.done() + with raises(CancelledError): + doubles_future.exception() + + else: + # old behaviour, where aclose() could unblock a Task + # Unblock and watch StopAsyncIteration propagate + await doubles.aclose() + await sleep(0.05) + assert doubles_future.done() + assert isinstance(doubles_future.exception(), StopAsyncIteration) with raises(StopAsyncIteration): await anext(singles) - @mark.asyncio - async def can_unset_closed_state_of_async_iterable(): - items = [1, 2, 3] - - class Iterable: - def __init__(self): - self.is_closed = False - - def __aiter__(self): - return self - - async def __anext__(self): - if self.is_closed: - raise StopAsyncIteration - try: - return items.pop(0) - except IndexError: - raise StopAsyncIteration - - async def aclose(self): - self.is_closed = True - - iterable = Iterable() - doubles = MapAsyncIterable(iterable, lambda x: x + x) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 - assert not iterable.is_closed - await doubles.aclose() - assert iterable.is_closed - with raises(StopAsyncIteration): - await anext(iterable) - with raises(StopAsyncIteration): - await anext(doubles) - assert doubles.is_closed - - iterable.is_closed = False - doubles.is_closed = False - assert not doubles.is_closed - - assert await anext(doubles) == 6 - assert not doubles.is_closed - assert not iterable.is_closed - with raises(StopAsyncIteration): - await anext(iterable) - with raises(StopAsyncIteration): - await anext(doubles) - assert not doubles.is_closed - assert not iterable.is_closed - @mark.asyncio async def can_cancel_async_iterable_while_waiting(): class Iterable: From 0860e3b02b56aeda23a72e4e892faa7d4261dcb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 23 Apr 2023 18:52:48 +0000 Subject: [PATCH 2/3] iterate over and aclose() the iterator --- src/graphql/execution/map_async_iterable.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/graphql/execution/map_async_iterable.py b/src/graphql/execution/map_async_iterable.py index 885fc1bc..b99c17ba 100644 --- a/src/graphql/execution/map_async_iterable.py +++ b/src/graphql/execution/map_async_iterable.py @@ -23,7 +23,7 @@ class MapAsyncIterable: """ def __init__(self, iterable: AsyncIterable, callback: Callable) -> None: - self.iterable = iterable + self.iterator = iterable.__aiter__() self.callback = callback self._ageniter = self._agen() self.is_closed = False # used by unittests @@ -38,13 +38,13 @@ async def __anext__(self) -> Any: async def _agen(self) -> Any: try: - async for v in self.iterable: + async for v in self.iterator: result = self.callback(v) yield (await result) if isawaitable(result) else result finally: self.is_closed = True - if hasattr(self.iterable, "aclose"): - await self.iterable.aclose() + if hasattr(self.iterator, "aclose"): + await self.iterator.aclose() # This is not a standard method and is only used in unittests. Should be removed. async def athrow( From d0ee2252da051cf64880f650c50dbd8749fd0d96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Thu, 11 May 2023 13:05:07 +0000 Subject: [PATCH 3/3] Mapping method must be async --- src/graphql/execution/map_async_iterable.py | 17 +++++---- tests/execution/test_map_async_iterable.py | 42 +++++++++++++-------- 2 files changed, 36 insertions(+), 23 deletions(-) diff --git a/src/graphql/execution/map_async_iterable.py b/src/graphql/execution/map_async_iterable.py index b99c17ba..cc6715b6 100644 --- a/src/graphql/execution/map_async_iterable.py +++ b/src/graphql/execution/map_async_iterable.py @@ -1,8 +1,7 @@ from __future__ import annotations # Python < 3.10 -from inspect import isawaitable from types import TracebackType -from typing import Any, AsyncIterable, Callable, Optional, Type, Union +from typing import Any, AsyncIterable, Awaitable, Callable, Optional, Type, Union __all__ = ["MapAsyncIterable"] @@ -11,18 +10,21 @@ # The following is a class because its type is checked in the code. # otherwise, it could be implemented as a simple async generator function + # noinspection PyAttributeOutsideInit class MapAsyncIterable: """Map an AsyncIterable over a callback function. Given an AsyncIterable and a callback function, return an AsyncIterator which - produces values mapped via calling the callback function. + produces values mapped via calling the callback async function. - When the resulting AsyncIterator is closed, the underlying AsyncIterable will also - be closed. + Similar to an AsyncGenerator, an `aclose()` method is provivde which + will close the underlying AsyncIterable be if it has an `aclose()` method. """ - def __init__(self, iterable: AsyncIterable, callback: Callable) -> None: + def __init__( + self, iterable: AsyncIterable[Any], callback: Callable[[Any], Awaitable[Any]] + ) -> None: self.iterator = iterable.__aiter__() self.callback = callback self._ageniter = self._agen() @@ -39,8 +41,7 @@ async def __anext__(self) -> Any: async def _agen(self) -> Any: try: async for v in self.iterator: - result = self.callback(v) - yield (await result) if isawaitable(result) else result + yield await self.callback(v) finally: self.is_closed = True if hasattr(self.iterator, "aclose"): diff --git a/tests/execution/test_map_async_iterable.py b/tests/execution/test_map_async_iterable.py index aab1e7fd..09a0b28d 100644 --- a/tests/execution/test_map_async_iterable.py +++ b/tests/execution/test_map_async_iterable.py @@ -18,6 +18,14 @@ async def anext(iterator): return await iterator.__anext__() +async def map_single(x): + return x + + +async def map_doubles(x): + return x + x + + def describe_map_async_iterable(): @mark.asyncio async def maps_over_async_generator(): @@ -26,7 +34,7 @@ async def source(): yield 2 yield 3 - doubles = MapAsyncIterable(source(), lambda x: x + x) + doubles = MapAsyncIterable(source(), map_doubles) assert await anext(doubles) == 2 assert await anext(doubles) == 4 @@ -48,7 +56,7 @@ async def __anext__(self): except IndexError: raise StopAsyncIteration - doubles = MapAsyncIterable(Iterable(), lambda x: x + x) + doubles = MapAsyncIterable(Iterable(), map_doubles) values = [value async for value in doubles] @@ -62,7 +70,7 @@ async def source(): yield 2 yield 3 - doubles = MapAsyncIterable(source(), lambda x: x + x) + doubles = MapAsyncIterable(source(), map_doubles) values = [value async for value in doubles] @@ -91,7 +99,7 @@ async def source(): yield 2 yield 3 # pragma: no cover - doubles = MapAsyncIterable(source(), lambda x: x + x) + doubles = MapAsyncIterable(source(), map_doubles) assert await anext(doubles) == 2 assert await anext(doubles) == 4 @@ -119,7 +127,7 @@ async def __anext__(self): except IndexError: # pragma: no cover raise StopAsyncIteration - doubles = MapAsyncIterable(Iterable(), lambda x: x + x) + doubles = MapAsyncIterable(Iterable(), map_doubles) assert await anext(doubles) == 2 assert await anext(doubles) == 4 @@ -145,7 +153,7 @@ async def source(): yield "Done" yield "Last" # pragma: no cover - doubles = MapAsyncIterable(source(), lambda x: x + x) + doubles = MapAsyncIterable(source(), map_doubles) assert await anext(doubles) == 2 assert await anext(doubles) == 4 @@ -168,7 +176,7 @@ async def __anext__(self): except IndexError: # pragma: no cover raise StopAsyncIteration - doubles = MapAsyncIterable(Iterable(), lambda x: x + x) + doubles = MapAsyncIterable(Iterable(), map_doubles) assert await anext(doubles) == 2 assert await anext(doubles) == 4 @@ -194,7 +202,7 @@ def __aiter__(self): async def __anext__(self): return 1 - one = MapAsyncIterable(Iterable(), lambda x: x) + one = MapAsyncIterable(Iterable(), map_single) assert await anext(one) == 1 @@ -220,7 +228,7 @@ def __aiter__(self): async def __anext__(self): return 1 - one = MapAsyncIterable(Iterable(), lambda x: x) + one = MapAsyncIterable(Iterable(), map_single) assert await anext(one) == 1 @@ -247,7 +255,7 @@ async def source(): except Exception as e: yield e - doubles = MapAsyncIterable(source(), lambda x: x + x) + doubles = MapAsyncIterable(source(), map_doubles) assert await anext(doubles) == 2 assert await anext(doubles) == 4 @@ -262,7 +270,7 @@ async def source(): yield "Hello" raise RuntimeError("Goodbye") - doubles = MapAsyncIterable(source(), lambda x: x + x) + doubles = MapAsyncIterable(source(), map_doubles) assert await anext(doubles) == "HelloHello" @@ -276,7 +284,7 @@ async def does_not_normally_map_over_externally_thrown_errors(): async def source(): yield "Hello" - doubles = MapAsyncIterable(source(), lambda x: x + x) + doubles = MapAsyncIterable(source(), map_doubles) assert await anext(doubles) == "HelloHello" @@ -305,7 +313,7 @@ async def __anext__(self): raise StopAsyncIteration return self.counter - def double(x): + async def double(x): return x + x for iterable in source, Source: @@ -377,7 +385,11 @@ async def source(): yield 3 # pragma: no cover singles = source() - doubles = MapAsyncIterable(singles, lambda x: x * 2) + + async def double(x): + return x * 2 + + doubles = MapAsyncIterable(singles, double) result = await anext(doubles) assert result == 2 @@ -431,7 +443,7 @@ async def aclose(self): self.is_closed = True iterable = Iterable() - doubles = MapAsyncIterable(iterable, lambda x: x + x) # pragma: no cover exit + doubles = MapAsyncIterable(iterable, map_doubles) # pragma: no cover exit cancelled = False async def iterator_task():