Skip to content

Commit 2012613

Browse files
committed
Rewrite MapAsyncIterable using async generator semantics
iterate over and aclose() the iterator Mapping method must be async turn MapAsyncIterable into an AsyncGenerator
1 parent 62749e5 commit 2012613

File tree

8 files changed

+92
-216
lines changed

8 files changed

+92
-216
lines changed

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@
151151
Middleware
152152
asyncio.events.AbstractEventLoop
153153
graphql.execution.collect_fields.FieldsAndPatches
154-
graphql.execution.map_async_iterable.MapAsyncIterable
154+
graphql.execution.map_async_iterable.map_async_iterable
155155
graphql.execution.Middleware
156156
graphql.execution.execute.DeferredFragmentRecord
157157
graphql.execution.execute.ExperimentalExecuteMultipleResults

docs/modules/execution.rst

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,6 @@ Execution
5757

5858
.. autofunction:: create_source_event_stream
5959

60-
.. autoclass:: MapAsyncIterable
61-
6260
.. autoclass:: Middleware
6361

6462
.. autoclass:: MiddlewareManager

src/graphql/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@
439439
# Subscription
440440
subscribe,
441441
create_source_event_stream,
442-
MapAsyncIterable,
442+
map_async_iterable,
443443
# Middleware
444444
Middleware,
445445
MiddlewareManager,
@@ -707,7 +707,7 @@
707707
"MiddlewareManager",
708708
"subscribe",
709709
"create_source_event_stream",
710-
"MapAsyncIterable",
710+
"map_async_iterable",
711711
"validate",
712712
"ValidationContext",
713713
"ValidationRule",

src/graphql/execution/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
FormattedIncrementalResult,
3333
Middleware,
3434
)
35-
from .map_async_iterable import MapAsyncIterable
35+
from .map_async_iterable import map_async_iterable
3636
from .middleware import MiddlewareManager
3737
from .values import get_argument_values, get_directive_values, get_variable_values
3838

@@ -62,7 +62,7 @@
6262
"FormattedIncrementalDeferResult",
6363
"FormattedIncrementalStreamResult",
6464
"FormattedIncrementalResult",
65-
"MapAsyncIterable",
65+
"map_async_iterable",
6666
"Middleware",
6767
"MiddlewareManager",
6868
"get_argument_values",

src/graphql/execution/execute.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070
)
7171
from .collect_fields import FieldsAndPatches, collect_fields, collect_subfields
7272
from .flatten_async_iterable import flatten_async_iterable
73-
from .map_async_iterable import MapAsyncIterable
73+
from .map_async_iterable import map_async_iterable
7474
from .middleware import MiddlewareManager
7575
from .values import get_argument_values, get_directive_values, get_variable_values
7676

@@ -1654,7 +1654,7 @@ async def callback(payload: Any) -> AsyncGenerator:
16541654
await result if isawaitable(result) else result # type: ignore
16551655
)
16561656

1657-
return flatten_async_iterable(MapAsyncIterable(result_or_stream, callback))
1657+
return flatten_async_iterable(map_async_iterable(result_or_stream, callback))
16581658

16591659
def execute_deferred_fragment(
16601660
self,
@@ -2319,18 +2319,20 @@ def subscribe(
23192319
if isinstance(result, ExecutionResult):
23202320
return result
23212321
if isinstance(result, AsyncIterable):
2322-
return MapAsyncIterable(result, ensure_single_execution_result)
2322+
return map_async_iterable(result, ensure_single_execution_result)
23232323

23242324
async def await_result() -> Union[AsyncIterator[ExecutionResult], ExecutionResult]:
23252325
result_or_iterable = await result # type: ignore
23262326
if isinstance(result_or_iterable, AsyncIterable):
2327-
return MapAsyncIterable(result_or_iterable, ensure_single_execution_result)
2327+
return map_async_iterable(
2328+
result_or_iterable, ensure_single_execution_result
2329+
)
23282330
return result_or_iterable
23292331

23302332
return await_result()
23312333

23322334

2333-
def ensure_single_execution_result(
2335+
async def ensure_single_execution_result(
23342336
result: Union[
23352337
ExecutionResult,
23362338
InitialIncrementalExecutionResult,
Lines changed: 16 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,118 +1,26 @@
11
from __future__ import annotations # Python < 3.10
22

3-
from asyncio import CancelledError, Event, Task, ensure_future, wait
4-
from concurrent.futures import FIRST_COMPLETED
5-
from inspect import isasyncgen, isawaitable
6-
from types import TracebackType
7-
from typing import Any, AsyncIterable, Callable, Optional, Set, Type, Union, cast
3+
from typing import Any, AsyncIterable, Awaitable, Callable
84

95

10-
__all__ = ["MapAsyncIterable"]
6+
__all__ = ["map_async_iterable"]
117

128

13-
# noinspection PyAttributeOutsideInit
14-
class MapAsyncIterable:
9+
async def map_async_iterable(
10+
iterable: AsyncIterable[Any], callback: Callable[[Any], Awaitable[Any]]
11+
) -> None:
1512
"""Map an AsyncIterable over a callback function.
1613
17-
Given an AsyncIterable and a callback function, return an AsyncIterator which
18-
produces values mapped via calling the callback function.
19-
20-
When the resulting AsyncIterator is closed, the underlying AsyncIterable will also
21-
be closed.
14+
Given an AsyncIterable and an async callback callable, return an AsyncGenerator
15+
which produces values mapped via calling the callback.
16+
If the inner iterator supports an `aclose()` method, it will be called when
17+
the generator finishes or closes.
2218
"""
2319

24-
def __init__(self, iterable: AsyncIterable, callback: Callable) -> None:
25-
self.iterator = iterable.__aiter__()
26-
self.callback = callback
27-
self._close_event = Event()
28-
29-
def __aiter__(self) -> MapAsyncIterable:
30-
"""Get the iterator object."""
31-
return self
32-
33-
async def __anext__(self) -> Any:
34-
"""Get the next value of the iterator."""
35-
if self.is_closed:
36-
if not isasyncgen(self.iterator):
37-
raise StopAsyncIteration
38-
value = await self.iterator.__anext__()
39-
else:
40-
aclose = ensure_future(self._close_event.wait())
41-
anext = ensure_future(self.iterator.__anext__())
42-
43-
try:
44-
pending: Set[Task] = (
45-
await wait([aclose, anext], return_when=FIRST_COMPLETED)
46-
)[1]
47-
except CancelledError:
48-
# cancel underlying tasks and close
49-
aclose.cancel()
50-
anext.cancel()
51-
await self.aclose()
52-
raise # re-raise the cancellation
53-
54-
for task in pending:
55-
task.cancel()
56-
57-
if aclose.done():
58-
raise StopAsyncIteration
59-
60-
error = anext.exception()
61-
if error:
62-
raise error
63-
64-
value = anext.result()
65-
66-
result = self.callback(value)
67-
68-
return await result if isawaitable(result) else result
69-
70-
async def athrow(
71-
self,
72-
type_: Union[BaseException, Type[BaseException]],
73-
value: Optional[BaseException] = None,
74-
traceback: Optional[TracebackType] = None,
75-
) -> None:
76-
"""Throw an exception into the asynchronous iterator."""
77-
if self.is_closed:
78-
return
79-
athrow = getattr(self.iterator, "athrow", None)
80-
if athrow:
81-
await athrow(type_, value, traceback)
82-
else:
83-
await self.aclose()
84-
if value is None:
85-
if traceback is None:
86-
raise type_
87-
value = (
88-
type_
89-
if isinstance(value, BaseException)
90-
else cast(Type[BaseException], type_)()
91-
)
92-
if traceback is not None:
93-
value = value.with_traceback(traceback)
94-
raise value
95-
96-
async def aclose(self) -> None:
97-
"""Close the iterator."""
98-
if not self.is_closed:
99-
aclose = getattr(self.iterator, "aclose", None)
100-
if aclose:
101-
try:
102-
await aclose()
103-
except RuntimeError:
104-
pass
105-
self.is_closed = True
106-
107-
@property
108-
def is_closed(self) -> bool:
109-
"""Check whether the iterator is closed."""
110-
return self._close_event.is_set()
111-
112-
@is_closed.setter
113-
def is_closed(self, value: bool) -> None:
114-
"""Mark the iterator as closed."""
115-
if value:
116-
self._close_event.set()
117-
else:
118-
self._close_event.clear()
20+
aiter = iterable.__aiter__()
21+
try:
22+
async for element in aiter:
23+
yield await callback(element)
24+
finally:
25+
if hasattr(aiter, "aclose"):
26+
await aiter.aclose()

tests/execution/test_customize.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from inspect import isasyncgen
2+
13
from pytest import mark
24

3-
from graphql.execution import ExecutionContext, MapAsyncIterable, execute, subscribe
5+
from graphql.execution import ExecutionContext, execute, subscribe
46
from graphql.language import parse
57
from graphql.type import GraphQLField, GraphQLObjectType, GraphQLSchema, GraphQLString
68

@@ -77,7 +79,7 @@ async def custom_foo():
7779
root_value=Root(),
7880
subscribe_field_resolver=lambda root, _info: root.custom_foo(),
7981
)
80-
assert isinstance(subscription, MapAsyncIterable)
82+
assert isasyncgen(subscription)
8183

8284
assert await anext(subscription) == (
8385
{"foo": "FooValue"},
@@ -121,6 +123,6 @@ def resolve_foo(message, _info):
121123
context_value={},
122124
execution_context_class=TestExecutionContext,
123125
)
124-
assert isinstance(subscription, MapAsyncIterable)
126+
assert isasyncgen(subscription)
125127

126128
assert await anext(subscription) == ({"foo": "bar"}, None)

0 commit comments

Comments
 (0)