Skip to content

Commit dd70a55

Browse files
committed
Add lazy execution prototype
This PR is a proof of concept of adding lazy evaluation support to graphql-core. It follows a similar API to [graphql-ruby](https://graphql-ruby.org/schema/lazy_execution.html) and it allows the developer to define "lazy" type to enable batching without using asyncio. The tests illustrate how this can be used to enable a dataloader pattern. N.B. The DeferredValue object is very similar in functionality to the Promise library from graphql-core v2. I decided to reimplement a subset rather than use it directly though because it's scope is bigger than what I needed. It's a purely internal implementation detail though and can be replaced in future.
1 parent 1560449 commit dd70a55

File tree

3 files changed

+540
-1
lines changed

3 files changed

+540
-1
lines changed

src/graphql/execution/execute.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
Path,
3838
Undefined,
3939
)
40+
from ..utilities.deferred_value import DeferredValue, deferred_dict, deferred_list
4041
from ..type import (
4142
GraphQLAbstractType,
4243
GraphQLField,
@@ -222,6 +223,11 @@ def __init__(
222223
self.is_awaitable = is_awaitable
223224
self._subfields_cache: Dict[Tuple, Dict[str, List[FieldNode]]] = {}
224225

226+
self._deferred_values: List[Tuple[DeferredValue, Any]] = []
227+
228+
def is_lazy(self, value: Any) -> bool:
229+
return False
230+
225231
@classmethod
226232
def build(
227233
cls,
@@ -350,12 +356,25 @@ def execute_operation(
350356

351357
path = None
352358

353-
return (
359+
result = (
354360
self.execute_fields_serially
355361
if operation.operation == OperationType.MUTATION
356362
else self.execute_fields
357363
)(root_type, root_value, path, root_fields)
358364

365+
while len(self._deferred_values) > 0:
366+
for d in list(self._deferred_values):
367+
self._deferred_values.remove(d)
368+
res = d[1].get()
369+
d[0].resolve(res)
370+
371+
if isinstance(result, DeferredValue):
372+
if result.is_rejected:
373+
raise cast(Exception, result.reason)
374+
return result.value
375+
376+
return result
377+
359378
def execute_fields_serially(
360379
self,
361380
parent_type: GraphQLObjectType,
@@ -432,6 +451,7 @@ def execute_fields(
432451
is_awaitable = self.is_awaitable
433452
awaitable_fields: List[str] = []
434453
append_awaitable = awaitable_fields.append
454+
contains_deferred = False
435455
for response_name, field_nodes in fields.items():
436456
field_path = Path(path, response_name, parent_type.name)
437457
result = self.execute_field(
@@ -441,6 +461,11 @@ def execute_fields(
441461
results[response_name] = result
442462
if is_awaitable(result):
443463
append_awaitable(response_name)
464+
if isinstance(result, DeferredValue):
465+
contains_deferred = True
466+
467+
if contains_deferred:
468+
return deferred_dict(results)
444469

445470
# If there are no coroutines, we can just return the object
446471
if not awaitable_fields:
@@ -634,6 +659,23 @@ def complete_value(
634659
if result is None or result is Undefined:
635660
return None
636661

662+
if self.is_lazy(result):
663+
def handle_resolve(resolved: Any) -> Any:
664+
return self.complete_value(
665+
return_type, field_nodes, info, path, resolved
666+
)
667+
668+
def handle_error(raw_error: Exception) -> None:
669+
raise raw_error
670+
671+
deferred = DeferredValue()
672+
self._deferred_values.append((
673+
deferred, result
674+
))
675+
676+
completed = deferred.then(handle_resolve, handle_error)
677+
return completed
678+
637679
# If field type is List, complete each item in the list with inner type
638680
if is_list_type(return_type):
639681
return self.complete_list_value(
@@ -705,6 +747,7 @@ async def async_iterable_to_list(
705747
append_awaitable = awaitable_indices.append
706748
completed_results: List[Any] = []
707749
append_result = completed_results.append
750+
contains_deferred = False
708751
for index, item in enumerate(result):
709752
# No need to modify the info object containing the path, since from here on
710753
# it is not ever accessed by resolver functions.
@@ -746,6 +789,9 @@ async def await_completed(item: Any, item_path: Path) -> Any:
746789
return None
747790

748791
completed_item = await_completed(completed_item, item_path)
792+
if isinstance(completed_item, DeferredValue):
793+
contains_deferred = True
794+
749795
except Exception as raw_error:
750796
error = located_error(raw_error, field_nodes, item_path.as_list())
751797
self.handle_field_error(error, item_type)
@@ -755,6 +801,9 @@ async def await_completed(item: Any, item_path: Path) -> Any:
755801
append_awaitable(index)
756802
append_result(completed_item)
757803

804+
if contains_deferred is True:
805+
return deferred_list(completed_results)
806+
758807
if not awaitable_indices:
759808
return completed_results
760809

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
from typing import Any, Optional, List, Callable, cast, Dict
2+
3+
4+
OnSuccessCallback = Callable[[Any], None]
5+
OnErrorCallback = Callable[[Exception], None]
6+
7+
8+
class DeferredValue:
9+
PENDING = -1
10+
REJECTED = 0
11+
RESOLVED = 1
12+
13+
_value: Optional[Any]
14+
_reason: Optional[Exception]
15+
_callbacks: List[OnSuccessCallback]
16+
_errbacks: List[OnErrorCallback]
17+
18+
def __init__(
19+
self,
20+
on_complete: Optional[OnSuccessCallback] = None,
21+
on_error: Optional[OnErrorCallback] = None,
22+
):
23+
self._state = self.PENDING
24+
self._value = None
25+
self._reason = None
26+
if on_complete:
27+
self._callbacks = [on_complete]
28+
else:
29+
self._callbacks = []
30+
if on_error:
31+
self._errbacks = [on_error]
32+
else:
33+
self._errbacks = []
34+
35+
def resolve(self, value: Any) -> None:
36+
if self._state != DeferredValue.PENDING:
37+
return
38+
39+
if isinstance(value, DeferredValue):
40+
value.add_callback(self.resolve)
41+
value.add_errback(self.reject)
42+
return
43+
44+
self._value = value
45+
self._state = self.RESOLVED
46+
47+
callbacks = self._callbacks
48+
self._callbacks = []
49+
for callback in callbacks:
50+
try:
51+
callback(value)
52+
except Exception:
53+
# Ignore errors in callbacks
54+
pass
55+
56+
def reject(self, reason: Exception) -> None:
57+
if self._state != DeferredValue.PENDING:
58+
return
59+
60+
self._reason = reason
61+
self._state = self.REJECTED
62+
63+
errbacks = self._errbacks
64+
self._errbacks = []
65+
for errback in errbacks:
66+
try:
67+
errback(reason)
68+
except Exception:
69+
# Ignore errors in errback
70+
pass
71+
72+
def then(
73+
self,
74+
on_complete: Optional[OnSuccessCallback] = None,
75+
on_error: Optional[OnErrorCallback] = None,
76+
) -> "DeferredValue":
77+
ret = DeferredValue()
78+
79+
def call_and_resolve(v: Any) -> None:
80+
try:
81+
if on_complete:
82+
ret.resolve(on_complete(v))
83+
else:
84+
ret.resolve(v)
85+
except Exception as e:
86+
ret.reject(e)
87+
88+
def call_and_reject(r: Exception) -> None:
89+
try:
90+
if on_error:
91+
ret.resolve(on_error(r))
92+
else:
93+
ret.reject(r)
94+
except Exception as e:
95+
ret.reject(e)
96+
97+
self.add_callback(call_and_resolve)
98+
self.add_errback(call_and_resolve)
99+
100+
return ret
101+
102+
def add_callback(self, callback: OnSuccessCallback) -> None:
103+
if self._state == self.PENDING:
104+
self._callbacks.append(callback)
105+
return
106+
107+
if self._state == self.RESOLVED:
108+
callback(self._value)
109+
110+
def add_errback(self, callback: OnErrorCallback) -> None:
111+
if self._state == self.PENDING:
112+
self._errbacks.append(callback)
113+
return
114+
115+
if self._state == self.REJECTED:
116+
callback(cast(Exception, self._reason))
117+
118+
@property
119+
def is_resolved(self) -> bool:
120+
return self._state == self.RESOLVED
121+
122+
@property
123+
def is_rejected(self) -> bool:
124+
return self._state == self.REJECTED
125+
126+
@property
127+
def value(self) -> Any:
128+
return self._value
129+
130+
@property
131+
def reason(self) -> Optional[Exception]:
132+
return self._reason
133+
134+
135+
def deferred_dict(m: Dict[str, Any]) -> DeferredValue:
136+
"""
137+
A special function that takes a dictionary of deferred values
138+
and turns them into a deferred value that will ultimately resolve
139+
into a dictionary of values.
140+
"""
141+
if len(m) == 0:
142+
raise TypeError("Empty dict")
143+
144+
ret = DeferredValue()
145+
146+
plain_values = {
147+
key: value for key, value in m.items() if not isinstance(value, DeferredValue)
148+
}
149+
deferred_values = {
150+
key: value for key, value in m.items() if isinstance(value, DeferredValue)
151+
}
152+
153+
count = len(deferred_values)
154+
155+
def handle_success(_: Any) -> None:
156+
nonlocal count
157+
count -= 1
158+
if count == 0:
159+
value = plain_values
160+
161+
for k, p in deferred_values.items():
162+
value[k] = p.value
163+
164+
ret.resolve(value)
165+
166+
for p in deferred_values.values():
167+
p.add_callback(handle_success)
168+
p.add_errback(ret.reject)
169+
170+
return ret
171+
172+
173+
def deferred_list(l: List[Any]) -> DeferredValue:
174+
"""
175+
A special function that takes a list of deferred values
176+
and turns them into a deferred value for a list of values.
177+
"""
178+
if len(l) == 0:
179+
raise TypeError("Empty list")
180+
181+
ret = DeferredValue()
182+
183+
plain_values = {}
184+
deferred_values = {}
185+
for index, value in enumerate(l):
186+
if isinstance(value, DeferredValue):
187+
deferred_values[index] = value
188+
else:
189+
plain_values[index] = value
190+
191+
count = len(deferred_values)
192+
193+
def handle_success(_: Any) -> None:
194+
nonlocal count
195+
count -= 1
196+
if count == 0:
197+
values = []
198+
199+
for k in sorted(list(plain_values.keys()) + list(deferred_values.keys())):
200+
value = plain_values.get(k, None)
201+
if not value:
202+
value = deferred_values[k].value
203+
values.append(value)
204+
ret.resolve(values)
205+
206+
for p in l:
207+
p.add_callback(handle_success)
208+
p.add_errback(ret.reject)
209+
210+
return ret

0 commit comments

Comments
 (0)