Skip to content

Simplify MapAsyncIterable using async generator semantics #197

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 25 additions & 81 deletions src/graphql/execution/map_async_iterable.py
Original file line number Diff line number Diff line change
@@ -1,118 +1,62 @@
from __future__ import annotations # Python < 3.10

from asyncio import CancelledError, Event, Task, ensure_future, wait
from concurrent.futures import FIRST_COMPLETED
from inspect import isasyncgen, isawaitable
from types import TracebackType
from typing import Any, AsyncIterable, Callable, Optional, Set, Type, Union, cast
from typing import Any, AsyncIterable, Awaitable, Callable, Optional, Type, Union


__all__ = ["MapAsyncIterable"]


# The following is a class because its type is checked in the code.
# otherwise, it could be implemented as a simple async generator function


# noinspection PyAttributeOutsideInit
class MapAsyncIterable:
"""Map an AsyncIterable over a callback function.

Given an AsyncIterable and a callback function, return an AsyncIterator which
produces values mapped via calling the callback function.
produces values mapped via calling the callback async function.

When the resulting AsyncIterator is closed, the underlying AsyncIterable will also
be closed.
Similar to an AsyncGenerator, an `aclose()` method is provivde which
will close the underlying AsyncIterable be if it has an `aclose()` method.
"""

def __init__(self, iterable: AsyncIterable, callback: Callable) -> None:
def __init__(
self, iterable: AsyncIterable[Any], callback: Callable[[Any], Awaitable[Any]]
) -> None:
self.iterator = iterable.__aiter__()
self.callback = callback
self._close_event = Event()
self._ageniter = self._agen()
self.is_closed = False # used by unittests

def __aiter__(self) -> MapAsyncIterable:
"""Get the iterator object."""
return self

async def __anext__(self) -> Any:
"""Get the next value of the iterator."""
if self.is_closed:
if not isasyncgen(self.iterator):
raise StopAsyncIteration
value = await self.iterator.__anext__()
else:
aclose = ensure_future(self._close_event.wait())
anext = ensure_future(self.iterator.__anext__())

try:
pending: Set[Task] = (
await wait([aclose, anext], return_when=FIRST_COMPLETED)
)[1]
except CancelledError:
# cancel underlying tasks and close
aclose.cancel()
anext.cancel()
await self.aclose()
raise # re-raise the cancellation

for task in pending:
task.cancel()

if aclose.done():
raise StopAsyncIteration

error = anext.exception()
if error:
raise error
return await self._ageniter.__anext__()

value = anext.result()

result = self.callback(value)

return await result if isawaitable(result) else result
async def _agen(self) -> Any:
try:
async for v in self.iterator:
yield await self.callback(v)
finally:
self.is_closed = True
if hasattr(self.iterator, "aclose"):
await self.iterator.aclose()

# This is not a standard method and is only used in unittests. Should be removed.
async def athrow(
self,
type_: Union[BaseException, Type[BaseException]],
value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
"""Throw an exception into the asynchronous iterator."""
if self.is_closed:
return
athrow = getattr(self.iterator, "athrow", None)
if athrow:
await athrow(type_, value, traceback)
else:
await self.aclose()
if value is None:
if traceback is None:
raise type_
value = (
type_
if isinstance(value, BaseException)
else cast(Type[BaseException], type_)()
)
if traceback is not None:
value = value.with_traceback(traceback)
raise value
await self._ageniter.athrow(type_, value, traceback)

async def aclose(self) -> None:
"""Close the iterator."""
if not self.is_closed:
aclose = getattr(self.iterator, "aclose", None)
if aclose:
try:
await aclose()
except RuntimeError:
pass
self.is_closed = True

@property
def is_closed(self) -> bool:
"""Check whether the iterator is closed."""
return self._close_event.is_set()

@is_closed.setter
def is_closed(self, value: bool) -> None:
"""Mark the iterator as closed."""
if value:
self._close_event.set()
else:
self._close_event.clear()
await self._ageniter.aclose()
Loading