Skip to content

Commit 8f66d13

Browse files
Turn MapAsyncIterable into an AsyncGenerator (#199)
1 parent a71f6a9 commit 8f66d13

12 files changed

+160
-604
lines changed

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@
151151
Middleware
152152
asyncio.events.AbstractEventLoop
153153
graphql.execution.collect_fields.FieldsAndPatches
154-
graphql.execution.map_async_iterable.MapAsyncIterable
154+
graphql.execution.map_async_iterable.map_async_iterable
155155
graphql.execution.Middleware
156156
graphql.execution.execute.DeferredFragmentRecord
157157
graphql.execution.execute.ExperimentalIncrementalExecutionResults

docs/modules/execution.rst

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,6 @@ Execution
5757

5858
.. autofunction:: create_source_event_stream
5959

60-
.. autoclass:: MapAsyncIterable
61-
6260
.. autoclass:: Middleware
6361

6462
.. autoclass:: MiddlewareManager

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[tool.poetry]
22
name = "graphql-core"
33
version = "3.3.0a2"
4-
description = """
4+
description = """\
55
GraphQL-core is a Python port of GraphQL.js,\
66
the JavaScript reference implementation for GraphQL."""
77
license = "MIT"

src/graphql/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@
450450
# Subscription
451451
subscribe,
452452
create_source_event_stream,
453-
MapAsyncIterable,
453+
map_async_iterable,
454454
# Middleware
455455
Middleware,
456456
MiddlewareManager,
@@ -729,7 +729,7 @@
729729
"MiddlewareManager",
730730
"subscribe",
731731
"create_source_event_stream",
732-
"MapAsyncIterable",
732+
"map_async_iterable",
733733
"validate",
734734
"ValidationContext",
735735
"ValidationRule",

src/graphql/execution/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
FormattedIncrementalResult,
3131
Middleware,
3232
)
33-
from .map_async_iterable import MapAsyncIterable
33+
from .iterators import map_async_iterable
3434
from .middleware import MiddlewareManager
3535
from .values import get_argument_values, get_directive_values, get_variable_values
3636

@@ -58,7 +58,7 @@
5858
"FormattedIncrementalDeferResult",
5959
"FormattedIncrementalStreamResult",
6060
"FormattedIncrementalResult",
61-
"MapAsyncIterable",
61+
"map_async_iterable",
6262
"Middleware",
6363
"MiddlewareManager",
6464
"get_argument_values",

src/graphql/execution/execute.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,7 @@
7171
is_object_type,
7272
)
7373
from .collect_fields import FieldsAndPatches, collect_fields, collect_subfields
74-
from .flatten_async_iterable import flatten_async_iterable
75-
from .map_async_iterable import MapAsyncIterable
74+
from .iterators import flatten_async_iterable, map_async_iterable
7675
from .middleware import MiddlewareManager
7776
from .values import get_argument_values, get_directive_values, get_variable_values
7877

@@ -1650,7 +1649,7 @@ async def callback(payload: Any) -> AsyncGenerator:
16501649
await result if isawaitable(result) else result # type: ignore
16511650
)
16521651

1653-
return flatten_async_iterable(MapAsyncIterable(result_or_stream, callback))
1652+
return flatten_async_iterable(map_async_iterable(result_or_stream, callback))
16541653

16551654
def execute_deferred_fragment(
16561655
self,
@@ -2350,18 +2349,20 @@ def subscribe(
23502349
if isinstance(result, ExecutionResult):
23512350
return result
23522351
if isinstance(result, AsyncIterable):
2353-
return MapAsyncIterable(result, ensure_single_execution_result)
2352+
return map_async_iterable(result, ensure_single_execution_result)
23542353

23552354
async def await_result() -> Union[AsyncIterator[ExecutionResult], ExecutionResult]:
23562355
result_or_iterable = await result # type: ignore
23572356
if isinstance(result_or_iterable, AsyncIterable):
2358-
return MapAsyncIterable(result_or_iterable, ensure_single_execution_result)
2357+
return map_async_iterable(
2358+
result_or_iterable, ensure_single_execution_result
2359+
)
23592360
return result_or_iterable
23602361

23612362
return await_result()
23622363

23632364

2364-
def ensure_single_execution_result(
2365+
async def ensure_single_execution_result(
23652366
result: Union[
23662367
ExecutionResult,
23672368
InitialIncrementalExecutionResult,

src/graphql/execution/flatten_async_iterable.py renamed to src/graphql/execution/iterators.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
1-
from typing import AsyncGenerator, AsyncIterable, TypeVar, Union
1+
from __future__ import annotations # Python < 3.10
2+
3+
from typing import (
4+
Any,
5+
AsyncGenerator,
6+
AsyncIterable,
7+
Awaitable,
8+
Callable,
9+
TypeVar,
10+
Union,
11+
)
212

313

414
try:
@@ -15,10 +25,11 @@ async def aclosing(thing):
1525

1626

1727
T = TypeVar("T")
28+
V = TypeVar("V")
1829

1930
AsyncIterableOrGenerator = Union[AsyncGenerator[T, None], AsyncIterable[T]]
2031

21-
__all__ = ["flatten_async_iterable"]
32+
__all__ = ["flatten_async_iterable", "map_async_iterable"]
2233

2334

2435
async def flatten_async_iterable(
@@ -34,3 +45,23 @@ async def flatten_async_iterable(
3445
async with aclosing(sub_iterator) as items: # type: ignore
3546
async for item in items:
3647
yield item
48+
49+
50+
async def map_async_iterable(
51+
iterable: AsyncIterable[T], callback: Callable[[T], Awaitable[V]]
52+
) -> AsyncGenerator[V, None]:
53+
"""Map an AsyncIterable over a callback function.
54+
55+
Given an AsyncIterable and an async callback callable, return an AsyncGenerator
56+
which produces values mapped via calling the callback.
57+
If the inner iterator supports an `aclose()` method, it will be called when
58+
the generator finishes or closes.
59+
"""
60+
61+
aiter = iterable.__aiter__()
62+
try:
63+
async for element in aiter:
64+
yield await callback(element)
65+
finally:
66+
if hasattr(aiter, "aclose"):
67+
await aiter.aclose()

src/graphql/execution/map_async_iterable.py

Lines changed: 0 additions & 118 deletions
This file was deleted.

tests/execution/test_customize.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from inspect import isasyncgen
2+
13
from pytest import mark
24

3-
from graphql.execution import ExecutionContext, MapAsyncIterable, execute, subscribe
5+
from graphql.execution import ExecutionContext, execute, subscribe
46
from graphql.language import parse
57
from graphql.type import GraphQLField, GraphQLObjectType, GraphQLSchema, GraphQLString
68

@@ -77,7 +79,7 @@ async def custom_foo():
7779
root_value=Root(),
7880
subscribe_field_resolver=lambda root, _info: root.custom_foo(),
7981
)
80-
assert isinstance(subscription, MapAsyncIterable)
82+
assert isasyncgen(subscription)
8183

8284
assert await anext(subscription) == (
8385
{"foo": "FooValue"},
@@ -121,6 +123,6 @@ def resolve_foo(message, _info):
121123
context_value={},
122124
execution_context_class=TestExecutionContext,
123125
)
124-
assert isinstance(subscription, MapAsyncIterable)
126+
assert isasyncgen(subscription)
125127

126128
assert await anext(subscription) == ({"foo": "bar"}, None)

tests/execution/test_flatten_async_iterable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from pytest import mark, raises
44

5-
from graphql.execution.flatten_async_iterable import flatten_async_iterable
5+
from graphql.execution.iterators import flatten_async_iterable
66

77

88
try: # pragma: no cover

0 commit comments

Comments
 (0)