1
1
from __future__ import annotations # Python < 3.10
2
2
3
- from asyncio import CancelledError , Event , Task , ensure_future , wait
4
- from concurrent .futures import FIRST_COMPLETED
5
- from inspect import isasyncgen , isawaitable
3
+ from inspect import isawaitable
6
4
from types import TracebackType
7
- from typing import Any , AsyncIterable , Callable , Optional , Set , Type , Union , cast
5
+ from typing import Any , AsyncIterable , Callable , Optional , Type , Union
8
6
9
7
10
8
__all__ = ["MapAsyncIterable" ]
11
9
12
10
11
+ # The following is a class because its type is checked in the code.
12
+ # otherwise, it could be implemented as a simple async generator function
13
+
13
14
# noinspection PyAttributeOutsideInit
14
15
class MapAsyncIterable :
15
16
"""Map an AsyncIterable over a callback function.
@@ -22,97 +23,39 @@ class MapAsyncIterable:
22
23
"""
23
24
24
25
def __init__ (self , iterable : AsyncIterable , callback : Callable ) -> None :
25
- self .iterator = iterable . __aiter__ ()
26
+ self .iterable = iterable
26
27
self .callback = callback
27
- self ._close_event = Event ()
28
+ self ._ageniter = self ._agen ()
29
+ self .is_closed = False # used by unittests
28
30
29
31
def __aiter__ (self ) -> MapAsyncIterable :
30
32
"""Get the iterator object."""
31
33
return self
32
34
33
35
async def __anext__ (self ) -> Any :
34
36
"""Get the next value of the iterator."""
35
- if self .is_closed :
36
- if not isasyncgen (self .iterator ):
37
- raise StopAsyncIteration
38
- value = await self .iterator .__anext__ ()
39
- else :
40
- aclose = ensure_future (self ._close_event .wait ())
41
- anext = ensure_future (self .iterator .__anext__ ())
42
-
43
- try :
44
- pending : Set [Task ] = (
45
- await wait ([aclose , anext ], return_when = FIRST_COMPLETED )
46
- )[1 ]
47
- except CancelledError :
48
- # cancel underlying tasks and close
49
- aclose .cancel ()
50
- anext .cancel ()
51
- await self .aclose ()
52
- raise # re-raise the cancellation
53
-
54
- for task in pending :
55
- task .cancel ()
56
-
57
- if aclose .done ():
58
- raise StopAsyncIteration
59
-
60
- error = anext .exception ()
61
- if error :
62
- raise error
63
-
64
- value = anext .result ()
65
-
66
- result = self .callback (value )
67
-
68
- return await result if isawaitable (result ) else result
37
+ return await self ._ageniter .__anext__ ()
38
+
39
+ async def _agen (self ) -> Any :
40
+ try :
41
+ async for v in self .iterable :
42
+ result = self .callback (v )
43
+ yield (await result ) if isawaitable (result ) else result
44
+ finally :
45
+ self .is_closed = True
46
+ if hasattr (self .iterable , "aclose" ):
47
+ await self .iterable .aclose ()
69
48
49
+ # This is not a standard method and is only used in unittests. Should be removed.
70
50
async def athrow (
71
51
self ,
72
52
type_ : Union [BaseException , Type [BaseException ]],
73
53
value : Optional [BaseException ] = None ,
74
54
traceback : Optional [TracebackType ] = None ,
75
55
) -> None :
76
56
"""Throw an exception into the asynchronous iterator."""
77
- if self .is_closed :
78
- return
79
- athrow = getattr (self .iterator , "athrow" , None )
80
- if athrow :
81
- await athrow (type_ , value , traceback )
82
- else :
83
- await self .aclose ()
84
- if value is None :
85
- if traceback is None :
86
- raise type_
87
- value = (
88
- type_
89
- if isinstance (value , BaseException )
90
- else cast (Type [BaseException ], type_ )()
91
- )
92
- if traceback is not None :
93
- value = value .with_traceback (traceback )
94
- raise value
57
+ await self ._ageniter .athrow (type_ , value , traceback )
95
58
96
59
async def aclose (self ) -> None :
97
60
"""Close the iterator."""
98
- if not self .is_closed :
99
- aclose = getattr (self .iterator , "aclose" , None )
100
- if aclose :
101
- try :
102
- await aclose ()
103
- except RuntimeError :
104
- pass
105
- self .is_closed = True
106
-
107
- @property
108
- def is_closed (self ) -> bool :
109
- """Check whether the iterator is closed."""
110
- return self ._close_event .is_set ()
111
-
112
- @is_closed .setter
113
- def is_closed (self , value : bool ) -> None :
114
- """Mark the iterator as closed."""
115
- if value :
116
- self ._close_event .set ()
117
- else :
118
- self ._close_event .clear ()
61
+ await self ._ageniter .aclose ()
0 commit comments